[mlir] Support alignment in LLVM dialect GlobalOp
authorDumitru Potop <dumitru.potop@inria.fr>
Wed, 12 May 2021 06:45:25 +0000 (08:45 +0200)
committerAlex Zinenko <zinenko@google.com>
Wed, 12 May 2021 07:07:20 +0000 (09:07 +0200)
First step in adding alignment as an attribute to MLIR global definitions. Alignment can be specified for global objects in LLVM IR. It can also be specified as a named attribute in the LLVMIR dialect of MLIR. However, this attribute has no standing and is discarded during translation from MLIR to LLVM IR. This patch does two things: First, it adds the attribute to the syntax of the llvm.mlir.global operation, and by doing this it also adds accessors and verifications. The syntax is "align=XX" (with XX being an integer), placed right after the value of the operation. Second, it allows transforming this operation to and from LLVM IR. It is checked whether the value is an integer power of 2.

Reviewed By: ftynse, mehdi_amini

Differential Revision: https://reviews.llvm.org/D101492

14 files changed:
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/LLVMIR/global.mlir
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Target/LLVMIR/import.ll
mlir/test/Target/LLVMIR/llvmir.mlir

index f334faa..8612a5b 100644 (file)
@@ -139,7 +139,8 @@ private:
           IntegerType::get(builder.getContext(), 8), value.size());
       global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
                                               LLVM::Linkage::Internal, name,
-                                              builder.getStringAttr(value));
+                                              builder.getStringAttr(value),
+                                              /*alignment=*/0);
     }
 
     // Get the pointer to the first character in the global string.
index f334faa..8612a5b 100644 (file)
@@ -139,7 +139,8 @@ private:
           IntegerType::get(builder.getContext(), 8), value.size());
       global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
                                               LLVM::Linkage::Internal, name,
-                                              builder.getStringAttr(value));
+                                              builder.getStringAttr(value),
+                                              /*alignment=*/0);
     }
 
     // Get the pointer to the first character in the global string.
index 592bd77..64271d6 100644 (file)
@@ -907,6 +907,7 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
     StrAttr:$sym_name,
     Linkage:$linkage,
     OptionalAttr<AnyAttr>:$value,
+    OptionalAttr<I64Attr>:$alignment,
     DefaultValuedAttr<Confined<I32Attr, [IntNonNegative]>, "0">:$addr_space,
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
     OptionalAttr<StrAttr>:$section
@@ -991,13 +992,30 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
     // By default, "external" linkage is assumed and the global participates in
     // symbol resolution at link-time.
     llvm.mlir.global @glob(0 : f32) : f32
+
+    // Alignment is optional
+    llvm.mlir.global private constant @y(dense<1.0> : tensor<8xf32>) : !llvm.array<8 x f32>
+    ```
+
+    Like global variables in LLVM IR, globals can have an (optional)
+    alignment attribute using keyword `alignment`. The integer value of the
+    alignment must be a positive integer that is a power of 2.
+
+    Examples:
+
+    ```mlir
+    // Alignment is optional
+    llvm.mlir.global private constant @y(dense<1.0> : tensor<8xf32>) { alignment = 32 : i64 } : !llvm.array<8 x f32>
     ```
+
   }];
   let regions = (region AnyRegion:$initializer);
 
   let builders = [
     OpBuilder<(ins "Type":$type, "bool":$isConstant, "Linkage":$linkage,
-      "StringRef":$name, "Attribute":$value, CArg<"unsigned", "0">:$addrSpace,
+      "StringRef":$name, "Attribute":$value,
+      CArg<"uint64_t", "0">:$alignment,
+      CArg<"unsigned", "0">:$addrSpace,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
 
index 67f699a..912f58d 100644 (file)
@@ -38,7 +38,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
     auto globalOp = rewriter.create<LLVM::GlobalOp>(
         gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
         LLVM::Linkage::Internal, name, /*value=*/Attribute(),
-        gpu::GPUDialect::getWorkgroupAddressSpace());
+        /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace());
     workgroupBuffers.push_back(globalOp);
   }
 
index 0da9b95..3441792 100644 (file)
@@ -236,7 +236,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
         rewriter.setInsertionPointToStart(module.getBody());
         dstGlobal = rewriter.create<LLVM::GlobalOp>(
             loc, dstGlobalType,
-            /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute());
+            /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
+            /*alignment=*/0);
         rewriter.setInsertionPoint(launchOp);
       }
 
index 172f63b..1914812 100644 (file)
@@ -674,7 +674,8 @@ public:
     // Create `llvm.mlir.global` with initializer region containing one block.
     auto global = rewriter.create<LLVM::GlobalOp>(
         UnknownLoc::get(context), structType, /*isConstant=*/true,
-        LLVM::Linkage::External, executionModeInfoName, Attribute());
+        LLVM::Linkage::External, executionModeInfoName, Attribute(),
+        /*alignment=*/0);
     Location loc = global.getLoc();
     Region &region = global.getInitializerRegion();
     Block *block = rewriter.createBlock(&region);
@@ -752,7 +753,8 @@ public:
                        ? LLVM::Linkage::Private
                        : LLVM::Linkage::External;
     rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
-        op, dstType, isConstant, linkage, op.sym_name(), Attribute());
+        op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
+        /*alignment=*/0);
     return success();
   }
 };
index 3949cd4..dd9a061 100644 (file)
@@ -2302,7 +2302,7 @@ struct GlobalMemrefOpLowering
 
     rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
         global, arrayTy, global.constant(), linkage, global.sym_name(),
-        initialValue, type.getMemorySpaceAsInt());
+        initialValue, /*alignment=*/0, type.getMemorySpaceAsInt());
     return success();
   }
 };
index 9c7fb42..3ff9c6e 100644 (file)
@@ -31,6 +31,8 @@
 #include "llvm/Support/Mutex.h"
 #include "llvm/Support/SourceMgr.h"
 
+#include <iostream>
+
 using namespace mlir;
 using namespace mlir::LLVM;
 
@@ -1250,7 +1252,7 @@ static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; }
 
 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
                      bool isConstant, Linkage linkage, StringRef name,
-                     Attribute value, unsigned addrSpace,
+                     Attribute value, uint64_t alignment, unsigned addrSpace,
                      ArrayRef<NamedAttribute> attrs) {
   result.addAttribute(SymbolTable::getSymbolAttrName(),
                       builder.getStringAttr(name));
@@ -1259,6 +1261,13 @@ void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
     result.addAttribute("constant", builder.getUnitAttr());
   if (value)
     result.addAttribute("value", value);
+
+  // Only add an alignment attribute if the "alignment" input
+  // is different from 0. The value must also be a power of two, but
+  // this is tested in GlobalOp::verify, not here.
+  if (alignment != 0)
+    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
+
   result.addAttribute(getLinkageAttrName(),
                       builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
   if (addrSpace != 0)
@@ -1278,6 +1287,9 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
   if (auto value = op.getValueOrNull())
     p.printAttribute(value);
   p << ')';
+  // Note that the alignment attribute is printed using the
+  // default syntax here, even though it is an inherent attribute
+  // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
   p.printOptionalAttrDict(op->getAttrs(),
                           {SymbolTable::getSymbolAttrName(), "type", "constant",
                            "value", getLinkageAttrName(),
@@ -1493,10 +1505,12 @@ static int parseOptionalKeywordAlternative(OpAsmParser &parser,
 }
 
 namespace {
-template <typename Ty> struct EnumTraits {};
+template <typename Ty>
+struct EnumTraits {};
 
 #define REGISTER_ENUM_TYPE(Ty)                                                 \
-  template <> struct EnumTraits<Ty> {                                          \
+  template <>                                                                  \
+  struct EnumTraits<Ty> {                                                      \
     static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
     static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
   }
@@ -1521,7 +1535,8 @@ static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
 }
 
 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
-//               `(` attribute? `)` attribute-list? (`:` type)? region?
+//               `(` attribute? `)` align? attribute-list? (`:` type)? region?
+// align     ::= `align` `=` UINT64
 //
 // The type can be omitted for string attributes, in which case it will be
 // inferred from the value of the string as [strlen(value) x i8].
@@ -1648,6 +1663,13 @@ static LogicalResult verify(GlobalOp op) {
     }
   }
 
+  Optional<uint64_t> alignAttr = op.alignment();
+  if (alignAttr.hasValue()) {
+    uint64_t value = alignAttr.getValue();
+    if (!llvm::isPowerOf2_64(value))
+      return op->emitError() << "alignment attribute is not a power of 2";
+  }
+
   return success();
 }
 
@@ -2339,7 +2361,7 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
   auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
   auto global = moduleBuilder.create<LLVM::GlobalOp>(
       loc, type, /*isConstant=*/true, linkage, name,
-      builder.getStringAttr(value));
+      builder.getStringAttr(value), /*alignment=*/0);
 
   // Get the pointer to the first character in the global string.
   Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
index 5b1faee..4a1653a 100644 (file)
@@ -475,9 +475,19 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
   Type type = processType(GV->getValueType());
   if (!type)
     return nullptr;
-  GlobalOp op = b.create<GlobalOp>(
-      UnknownLoc::get(context), type, GV->isConstant(),
-      convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
+
+  uint64_t alignment = 0;
+  llvm::MaybeAlign maybeAlign = GV->getAlign();
+  if (maybeAlign.hasValue()) {
+    llvm::Align align = maybeAlign.getValue();
+    alignment = align.value();
+  }
+
+  GlobalOp op =
+      b.create<GlobalOp>(UnknownLoc::get(context), type, GV->isConstant(),
+                         convertLinkageFromLLVM(GV->getLinkage()),
+                         GV->getName(), valueAttr, alignment);
+
   if (GV->hasInitializer() && !valueAttr) {
     Region &r = op.getInitializerRegion();
     currentEntryBlock = b.createBlock(&r);
@@ -492,6 +502,7 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
         context, convertUnnamedAddrFromLLVM(GV->getUnnamedAddr())));
   if (GV->hasSection())
     op.sectionAttr(b.getStringAttr(GV->getSection()));
+
   return globals[GV] = op;
 }
 
index cc4a62f..24c01f3 100644 (file)
@@ -437,6 +437,10 @@ LogicalResult ModuleTranslation::convertGlobals() {
     if (op.section().hasValue())
       var->setSection(*op.section());
 
+    Optional<uint64_t> alignment = op.alignment();
+    if (alignment.hasValue())
+      var->setAlignment(llvm::MaybeAlign(alignment.getValue()));
+
     globalsMapping.try_emplace(op, var);
   }
 
index a5306c4..90cd8d2 100644 (file)
@@ -9,6 +9,12 @@ llvm.mlir.global constant @default_external_constant(42) : i64
 // CHECK: llvm.mlir.global internal @global(42 : i64) : i64
 llvm.mlir.global internal @global(42 : i64) : i64
 
+// CHECK: llvm.mlir.global private @aligned_global(42 : i64) {aligned = 64 : i64} : i64
+llvm.mlir.global private @aligned_global(42 : i64) {aligned = 64} : i64
+
+// CHECK: llvm.mlir.global private constant @aligned_global_const(42 : i64) {aligned = 32 : i64} : i64
+llvm.mlir.global private constant @aligned_global_const(42 : i64) {aligned = 32} : i64
+
 // CHECK: llvm.mlir.global internal constant @constant(3.700000e+01 : f64) : f32
 llvm.mlir.global internal constant @constant(37.0) : f32
 
index 0908afa..7491aba 100644 (file)
@@ -1,5 +1,10 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
 
+// expected-error@+1{{alignment attribute is not a power of 2}}
+llvm.mlir.global private @invalid_global_alignment(42 : i64) {alignment = 63} : i64
+
+// -----
+
 // expected-error@+1{{expected llvm.noalias argument attribute to be a unit attribute}}
 func @invalid_noalias(%arg0: i32 {llvm.noalias = 3}) {
   "llvm.return"() : () -> ()
index 027bc29..a8884bd 100644 (file)
@@ -3,9 +3,9 @@
 %struct.t = type {}
 %struct.s = type { %struct.t, i64 }
 
-; CHECK: llvm.mlir.global external @g1() : !llvm.struct<"struct.s", (struct<"struct.t", ()>, i64)>
+; CHECK: llvm.mlir.global external @g1() {alignment = 8 : i64} : !llvm.struct<"struct.s", (struct<"struct.t", ()>, i64)>
 @g1 = external global %struct.s, align 8
-; CHECK: llvm.mlir.global external @g2() : f64
+; CHECK: llvm.mlir.global external @g2() {alignment = 8 : i64} : f64
 @g2 = external global double, align 8
 ; CHECK: llvm.mlir.global internal @g3("string")
 @g3 = internal global [6 x i8] c"string"
 ; CHECK: llvm.mlir.global external @g5() : vector<8xi32>
 @g5 = external global <8 x i32>
 
+; CHECK: llvm.mlir.global private @alig32(42 : i64) {alignment = 32 : i64} : i64
+@alig32 = private global i64 42, align 32
+
+; CHECK: llvm.mlir.global private @alig64(42 : i64) {alignment = 64 : i64} : i64
+@alig64 = private global i64 42, align 64
+
 @g4 = external global i32, align 8
 ; CHECK: llvm.mlir.global internal constant @int_gep() : !llvm.ptr<i32> {
 ; CHECK-DAG:   %[[addr:[0-9]+]] = llvm.mlir.addressof @g4 : !llvm.ptr<i32>
index db40733..f7ad4c8 100644 (file)
@@ -1,5 +1,14 @@
 // RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
 
+// CHECK: @global_aligned32 = private global i64 42, align 32
+"llvm.mlir.global"() ({}) {sym_name = "global_aligned32", type = i64, value = 42 : i64, linkage = 0, alignment = 32} : () -> ()
+
+// CHECK: @global_aligned64 = private global i64 42, align 64
+llvm.mlir.global private @global_aligned64(42 : i64) {alignment = 64 : i64} : i64
+
+// CHECK: @global_aligned64_native = private global i64 42, align 64
+llvm.mlir.global private @global_aligned64_native(42 : i64) { alignment = 64 } : i64
+
 // CHECK: @i32_global = internal global i32 42
 llvm.mlir.global internal @i32_global(42: i32) : i32
 
@@ -1548,3 +1557,4 @@ module {
 // CHECK: ![[PIPELINE_DISABLE_NODE]] = !{!"llvm.loop.pipeline.disable", i1 true}
 // CHECK: ![[II_NODE]] = !{!"llvm.loop.pipeline.initiationinterval", i32 2}
 // CHECK: ![[ACCESS_GROUPS_NODE]] = !{![[GROUP_NODE1]], ![[GROUP_NODE2]]}
+