LLVM dialect: introduce llvm.addressof to access globals
authorAlex Zinenko <zinenko@google.com>
Mon, 12 Aug 2019 13:10:29 +0000 (06:10 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Aug 2019 13:10:54 +0000 (06:10 -0700)
This instruction is a local counterpart of llvm.global that takes a symbol
reference to a global and produces an SSA value containing the pointer to it.
Used in combination, these two operations allow one to use globals with other
operations expecting SSA values.  At a cost of IR indirection, we make sure the
functions don't implicitly capture the surrounding SSA values and remain
suitable for parallel processing.

PiperOrigin-RevId: 262908622

mlir/g3doc/Dialects/LLVM.md
mlir/include/mlir/LLVMIR/LLVMOps.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/LLVMIR/global.mlir
mlir/test/Target/llvmir.mlir

index 20afa32..4ede9f2 100644 (file)
@@ -255,6 +255,27 @@ Selection: `select <condition>, <lhs>, <rhs>`.
 These operations do not have LLVM IR counterparts but are necessary to map LLVM
 IR into MLIR.
 
+#### `llvm.addressof`
+
+Creates an SSA value containing a pointer to a global variable or constant
+defined by `llvm.global`.  The global value can be defined after its first
+referenced.  If the global value is a constant, storing into it is not allowed.
+
+Examples:
+
+```mlir {.mlir}
+func @foo() {
+  // Get the address of a global.
+  %0 = llvm.addressof @const : !llvm<"i32*">
+
+  // Use it as a regular pointer.
+  %1 = llvm.load %0 : !llvm<"i32*">
+}
+
+// Define the global.
+llvm.global @const(42 : i32) : !llvm.i32
+```
+
 #### `llvm.constant`
 
 Unlike LLVM IR, MLIR does not have first-class constant values. Therefore, all
index b626836..80e6284 100644 (file)
@@ -192,7 +192,7 @@ def FCmpPredicate : I64EnumAttr<
     [FCmpPredicateFALSE, FCmpPredicateOEQ, FCmpPredicateOGT, FCmpPredicateOGE,
      FCmpPredicateOLT, FCmpPredicateOLE, FCmpPredicateONE, FCmpPredicateORD,
      FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT,
-     FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE 
+     FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE
     ]> {
   let cppNamespace = "mlir::LLVM";
 
@@ -394,6 +394,32 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
 
 // Pseudo-operations (do not appear in LLVM IR but necessary for the dialect to
 // work correctly).
+def LLVM_AddressOfOp
+    : LLVM_OneResultOp<"addressof">,
+      Arguments<(ins SymbolRefAttr:$global_name)> {
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState *result, LLVMType resType, "
+              "StringRef name, ArrayRef<NamedAttribute> attrs = {}", [{
+      result->addAttribute("global_name", builder->getSymbolRefAttr(name));
+      result->addAttributes(attrs);
+      result->addTypes(resType);}]>,
+
+    OpBuilder<"Builder *builder, OperationState *result, GlobalOp global, "
+              "ArrayRef<NamedAttribute> attrs = {}", [{
+      build(builder, result, global.getType().getPointerTo(), global.sym_name(),
+            attrs);}]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return the llvm.global operation that defined the value referenced here.
+    GlobalOp getGlobal();
+  }];
+
+  let printer = "printAddressOfOp(p, *this);";
+  let parser = "return parseAddressOfOp(parser, result);";
+  let verifier = "return ::verify(*this);";
+}
+
 def LLVM_GlobalOp
     : LLVM_ZeroResultOp<"global">,
       Arguments<(ins TypeAttr:$type, UnitAttr:$constant, StrAttr:$sym_name,
index 04651b8..584d2a8 100644 (file)
@@ -89,6 +89,9 @@ private:
   ModuleOp mlirModule;
   std::unique_ptr<llvm::Module> llvmModule;
 
+  // Mappings between llvm.global definitions and corresponding globals.
+  llvm::DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;
+
 protected:
   // Mappings between original and translated values, used for lookups.
   llvm::StringMap<llvm::Function *> functionMapping;
index 378907e..199d401 100644 (file)
@@ -789,6 +789,49 @@ static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) {
 }
 
 //===----------------------------------------------------------------------===//
+// Printer, parser and verifier for LLVM::AddressOfOp.
+//===----------------------------------------------------------------------===//
+
+GlobalOp AddressOfOp::getGlobal() {
+  auto module = getParentOfType<ModuleOp>();
+  assert(module && "unexpected operation outside of a module");
+  return module.lookupSymbol<LLVM::GlobalOp>(global_name());
+}
+
+static void printAddressOfOp(OpAsmPrinter *p, AddressOfOp op) {
+  *p << op.getOperationName() << " @" << op.global_name();
+  p->printOptionalAttrDict(op.getAttrs(), {"global_name"});
+  *p << " : " << op.getResult()->getType();
+}
+
+static ParseResult parseAddressOfOp(OpAsmParser *parser,
+                                    OperationState *result) {
+  Attribute symRef;
+  Type type;
+  if (parser->parseAttribute(symRef, "global_name", result->attributes) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type) ||
+      parser->addTypeToList(type, result->types))
+    return failure();
+
+  if (!symRef.isa<SymbolRefAttr>())
+    return parser->emitError(parser->getNameLoc(), "expected symbol reference");
+  return success();
+}
+
+static LogicalResult verify(AddressOfOp op) {
+  auto global = op.getGlobal();
+  if (!global)
+    return op.emitOpError("must reference a global defined by 'llvm.global'");
+
+  if (global.getType().getPointerTo() != op.getResult()->getType())
+    return op.emitOpError(
+        "the type must be a pointer to the type of the referred global");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::ConstantOp.
 //===----------------------------------------------------------------------===//
 
index 5e1109b..7a84eae 100644 (file)
@@ -247,6 +247,18 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
     return success();
   }
 
+  // Emit addressof.  We need to look up the global value referenced by the
+  // operation and store it in the MLIR-to-LLVM value mapping.  This does not
+  // emit any LLVM instruction.
+  if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
+    LLVM::GlobalOp global = addressOfOp.getGlobal();
+    // The verifier should not have allowed this.
+    assert(global && "referencing an undefined global");
+
+    valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global);
+    return success();
+  }
+
   return opInst.emitError("unsupported or non-LLVM operation: ")
          << opInst.getName();
 }
@@ -290,21 +302,23 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
 // Create named global variables that correspond to llvm.global definitions.
 void ModuleTranslation::convertGlobals() {
   for (auto op : mlirModule.getOps<LLVM::GlobalOp>()) {
+    llvm::Constant *cst;
+    llvm::Type *type;
     // String attributes are treated separately because they cannot appear as
     // in-function constants and are thus not supported by getLLVMConstant.
     if (auto strAttr = op.value().dyn_cast<StringAttr>()) {
-      llvm::Constant *cst = llvm::ConstantDataArray::getString(
+      cst = llvm::ConstantDataArray::getString(
           llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
-      new llvm::GlobalVariable(*llvmModule, cst->getType(), op.constant(),
-                               llvm::GlobalValue::InternalLinkage, cst,
-                               op.sym_name());
-      return;
+      type = cst->getType();
+    } else {
+      type = op.getType().getUnderlyingType();
+      cst = getLLVMConstant(type, op.value(), op.getLoc());
     }
 
-    llvm::Type *type = op.getType().getUnderlyingType();
-    new llvm::GlobalVariable(
-        *llvmModule, type, op.constant(), llvm::GlobalValue::InternalLinkage,
-        getLLVMConstant(type, op.value(), op.getLoc()), op.sym_name());
+    auto *var = new llvm::GlobalVariable(*llvmModule, type, op.constant(),
+                                         llvm::GlobalValue::InternalLinkage,
+                                         cst, op.sym_name());
+    globalsMapping.try_emplace(op, var);
   }
 }
 
index ee46227..974ae19 100644 (file)
@@ -12,6 +12,17 @@ llvm.global constant @string("foobar") : !llvm<"[6 x i8]">
 // CHECK: llvm.global @string_notype("1234567")
 llvm.global @string_notype("1234567")
 
+// CHECK-LABEL: references
+func @references() {
+  // CHECK: llvm.addressof @global : !llvm<"i64*">
+  %0 = llvm.addressof @global : !llvm<"i64*">
+
+  // CHECK: llvm.addressof @string : !llvm<"[6 x i8]*">
+  %1 = llvm.addressof @string : !llvm<"[6 x i8]*">
+
+  llvm.return
+}
+
 // -----
 
 // expected-error @+1 {{op requires attribute 'sym_name'}}
@@ -54,3 +65,36 @@ llvm.global @i64_needs_type(0: i64)
 // expected-error @+1 {{expected zero or one type}}
 llvm.global @more_than_one_type(0) : !llvm.i64, !llvm.i32
 
+// -----
+
+llvm.global @foo(0: i32) : !llvm.i32
+
+func @bar() {
+  // expected-error @+2{{expected ':'}}
+  llvm.addressof @foo
+}
+
+// -----
+
+func @foo() {
+  // The attribute parser will consume the first colon-type, so we put two of
+  // them to trigger the attribute type mismatch error.
+  // expected-error @+1 {{expected symbol reference}}
+  llvm.addressof "foo" : i64 : !llvm<"void ()*">
+}
+
+// -----
+
+func @foo() {
+  // expected-error @+1 {{must reference a global defined by 'llvm.global'}}
+  llvm.addressof @foo : !llvm<"void ()*">
+}
+
+// -----
+
+llvm.global @foo(0: i32) : !llvm.i32
+
+func @bar() {
+  // expected-error @+1 {{the type must be a pointer to the type of the referred global}}
+  llvm.addressof @foo : !llvm<"i64*">
+}
index 6db5cec..36fa128 100644 (file)
@@ -33,6 +33,23 @@ func @empty() {
   llvm.return
 }
 
+// CHECK-LABEL: @global_refs
+func @global_refs() {
+  // Check load from globals.
+  // CHECK: load i32, i32* @i32_global
+  %0 = llvm.addressof @i32_global : !llvm<"i32*">
+  %1 = llvm.load %0 : !llvm<"i32*">
+
+  // Check the contracted form of load from array constants.
+  // CHECK: load i8, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @string_const, i64 0, i64 0)
+  %2 = llvm.addressof @string_const : !llvm<"[6 x i8]*">
+  %c0 = llvm.constant(0 : index) : !llvm.i64
+  %3 = llvm.getelementptr %2[%c0, %c0] : (!llvm<"[6 x i8]*">, !llvm.i64, !llvm.i64) -> !llvm<"i8*">
+  %4 = llvm.load %3 : !llvm<"i8*">
+
+  llvm.return
+}
+
 // CHECK-LABEL: declare void @body(i64)
 func @body(!llvm.i64)