Remove LLVM dependency on mlir::Module and instead check Traits.
authorTres Popp <tpopp@google.com>
Mon, 16 Dec 2019 09:35:03 +0000 (01:35 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 16 Dec 2019 09:45:44 +0000 (01:45 -0800)
PiperOrigin-RevId: 285724678

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/include/mlir/Target/NVVMIR.h
mlir/include/mlir/Target/ROCDLIR.h
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

index 83c30e6..5332a74 100644 (file)
@@ -198,6 +198,10 @@ Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name,
                           StringRef value, LLVM::Linkage linkage,
                           LLVM::LLVMDialect *llvmDialect);
 
+/// LLVM requires some operations to be inside of a Module operation. This
+/// function confirms that the Operation has the desired properties.
+bool satisfiesLLVMModule(Operation *op);
+
 } // end namespace LLVM
 } // end namespace mlir
 
index b957c82..2889012 100644 (file)
@@ -23,6 +23,7 @@
 #ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
 #define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
 
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/Value.h"
@@ -50,7 +51,9 @@ class LLVMFuncOp;
 class ModuleTranslation {
 public:
   template <typename T = ModuleTranslation>
-  static std::unique_ptr<llvm::Module> translateModule(ModuleOp m) {
+  static std::unique_ptr<llvm::Module> translateModule(Operation *m) {
+    if (!satisfiesLLVMModule(m))
+      return nullptr;
     if (failed(checkSupportedModuleOps(m)))
       return nullptr;
     auto llvmModule = prepareLLVMModule(m);
@@ -66,23 +69,30 @@ public:
     return std::move(translator.llvmModule);
   }
 
+  /// A helper method to get the single Block in an operation honoring LLVM's
+  /// module requirements.
+  static Block &getModuleBody(Operation *m) { return m->getRegion(0).front(); }
+
 protected:
   // Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
   // LLVM IR module.  The MLIR LLVM IR dialect holds a pointer to an
   // LLVMContext, the LLVM IR module will be created in that context.
-  explicit ModuleTranslation(ModuleOp module) : mlirModule(module) {}
+  explicit ModuleTranslation(Operation *module) : mlirModule(module) {
+    assert(satisfiesLLVMModule(mlirModule) &&
+           "mlirModule should honor LLVM's module semantics.");
+  }
   virtual ~ModuleTranslation() {}
 
   virtual LogicalResult convertOperation(Operation &op,
                                          llvm::IRBuilder<> &builder);
-  static std::unique_ptr<llvm::Module> prepareLLVMModule(ModuleOp m);
+  static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
 
   template <typename Range>
   SmallVector<llvm::Value *, 8> lookupValues(Range &&values);
 
 private:
   /// Check whether the module contains only supported ops directly in its body.
-  static LogicalResult checkSupportedModuleOps(ModuleOp m);
+  static LogicalResult checkSupportedModuleOps(Operation *m);
 
   LogicalResult convertFunctions();
   void convertGlobals();
@@ -94,7 +104,7 @@ private:
                                   Location loc);
 
   // Original and translated module.
-  ModuleOp mlirModule;
+  Operation *mlirModule;
   std::unique_ptr<llvm::Module> llvmModule;
 
   // Mappings between llvm.mlir.global definitions and corresponding globals.
index 3a4442e..ec9858e 100644 (file)
@@ -30,14 +30,14 @@ class Module;
 } // namespace llvm
 
 namespace mlir {
-class ModuleOp;
+class Operation;
 
-/// Convert the given MLIR module into NVVM IR. This conversion requires the
-/// registration of the LLVM IR dialect and will extract the LLVM context
-/// from the registered LLVM IR dialect.  In case of error, report it
-/// to the error handler registered with the MLIR context, if any (obtained from
+/// Convert the given LLVM-module-like operation into NVVM IR. This conversion
+/// requires the registration of the LLVM IR dialect and will extract the LLVM
+/// context from the registered LLVM IR dialect.  In case of error, report it to
+/// the error handler registered with the MLIR context, if any (obtained from
 /// the MLIR module), and return `nullptr`.
-std::unique_ptr<llvm::Module> translateModuleToNVVMIR(ModuleOp m);
+std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Operation *m);
 
 } // namespace mlir
 
index 6295a1b..fd00e94 100644 (file)
@@ -31,14 +31,14 @@ class Module;
 } // namespace llvm
 
 namespace mlir {
-class ModuleOp;
+class Operation;
 
-/// Convert the given MLIR module into ROCDL IR. This conversion requires the
-/// registration of the LLVM IR dialect and will extract the LLVM context
-/// from the registered LLVM IR dialect.  In case of error, report it
-/// to the error handler registered with the MLIR context, if any (obtained from
+/// Convert the given LLVM-module-like operation into ROCDL IR. This conversion
+/// requires the registration of the LLVM IR dialect and will extract the LLVM
+/// context from the registered LLVM IR dialect.  In case of error, report it to
+/// the error handler registered with the MLIR context, if any (obtained from
 /// the MLIR module), and return `nullptr`.
-std::unique_ptr<llvm::Module> translateModuleToROCDLIR(ModuleOp m);
+std::unique_ptr<llvm::Module> translateModuleToROCDLIR(Operation *m);
 
 } // namespace mlir
 
index d037d2e..9201da2 100644 (file)
@@ -790,9 +790,12 @@ static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &result) {
 //===----------------------------------------------------------------------===//
 
 GlobalOp AddressOfOp::getGlobal() {
-  auto module = getParentOfType<ModuleOp>();
+  Operation *module = getParentOp();
+  while (module && !satisfiesLLVMModule(module))
+    module = module->getParentOp();
   assert(module && "unexpected operation outside of a module");
-  return module.lookupSymbol<LLVM::GlobalOp>(global_name());
+  return dyn_cast_or_null<LLVM::GlobalOp>(
+      mlir::SymbolTable::lookupSymbolIn(module, global_name()));
 }
 
 static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
@@ -1030,7 +1033,9 @@ static LogicalResult verify(GlobalOp op) {
   if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
     return op.emitOpError(
         "expects type to be a valid element type for an LLVM pointer");
-  if (op.getParentOp() && !isa<ModuleOp>(op.getParentOp()))
+  if (op.getParentOp() &&
+      !(op.getParentOp()->hasTrait<OpTrait::SymbolTable>() &&
+        op.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()))
     return op.emitOpError("must appear at the module level");
 
   if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
@@ -1675,3 +1680,8 @@ Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
       loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
       ArrayRef<Value *>({cst0, cst0}));
 }
+
+bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
+  return op->hasTrait<OpTrait::SymbolTable>() &&
+         op->hasTrait<OpTrait::IsIsolatedFromAbove>();
+}
index 606e91b..728dc86 100644 (file)
@@ -58,7 +58,7 @@ static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
 class ModuleTranslation : public LLVM::ModuleTranslation {
 
 public:
-  explicit ModuleTranslation(ModuleOp module)
+  explicit ModuleTranslation(Operation *module)
       : LLVM::ModuleTranslation(module) {}
   ~ModuleTranslation() override {}
 
@@ -73,7 +73,7 @@ protected:
 };
 } // namespace
 
-std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
+std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Operation *m) {
   ModuleTranslation translation(m);
   auto llvmModule =
       LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
@@ -82,7 +82,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
 
   // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
   // function as a kernel.
-  for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
+  for (auto func :
+       ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
     if (!gpu::GPUDialect::isKernel(func))
       continue;
 
index dcd4d6c..7b7c368 100644 (file)
@@ -69,7 +69,7 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilder<> &builder,
 class ModuleTranslation : public LLVM::ModuleTranslation {
 
 public:
-  explicit ModuleTranslation(ModuleOp module)
+  explicit ModuleTranslation(Operation *module)
       : LLVM::ModuleTranslation(module) {}
   ~ModuleTranslation() override {}
 
@@ -84,7 +84,7 @@ protected:
 };
 } // namespace
 
-std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
+std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(Operation *m) {
   ModuleTranslation translation(m);
 
   // lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics)
@@ -94,7 +94,8 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToROCDLIR(ModuleOp m) {
   // foreach GPU kernel
   // 1. Insert AMDGPU_KERNEL calling convention.
   // 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
-  for (auto func : m.getOps<LLVM::LLVMFuncOp>()) {
+  for (auto func :
+       ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
     if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
       continue;
 
index f985fed..f5f9cca 100644 (file)
@@ -311,7 +311,7 @@ llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) {
 // Create named global variables that correspond to llvm.mlir.global
 // definitions.
 void ModuleTranslation::convertGlobals() {
-  for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
+  for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
     llvm::Type *type = op.getType().getUnderlyingType();
     llvm::Constant *cst = llvm::UndefValue::get(type);
     if (op.getValueOrNull()) {
@@ -470,10 +470,10 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   return success();
 }
 
-LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) {
-  for (Operation &o : m.getBody()->getOperations())
+LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
+  for (Operation &o : getModuleBody(m).getOperations())
     if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) &&
-        !isa<ModuleTerminatorOp>(&o))
+        !o.isKnownTerminator())
       return o.emitOpError("unsupported module-level operation");
   return success();
 }
@@ -481,7 +481,7 @@ LogicalResult ModuleTranslation::checkSupportedModuleOps(ModuleOp m) {
 LogicalResult ModuleTranslation::convertFunctions() {
   // Declare all functions first because there may be function calls that form a
   // call graph with cycles.
-  for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
+  for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
     llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
         function.getName(),
         llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
@@ -491,7 +491,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
   }
 
   // Convert functions.
-  for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
+  for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
     // Ignore external functions.
     if (function.isExternal())
       continue;
@@ -503,8 +503,9 @@ LogicalResult ModuleTranslation::convertFunctions() {
   return success();
 }
 
-std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(ModuleOp m) {
-  auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+std::unique_ptr<llvm::Module>
+ModuleTranslation::prepareLLVMModule(Operation *m) {
+  auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
   assert(dialect && "LLVM dialect must be registered");
 
   auto llvmModule = llvm::CloneModule(dialect->getLLVMModule());