[mlir] Add alignment attribute to memref.global
authorEugene Zhulenev <ezhulenev@google.com>
Thu, 7 Oct 2021 12:49:59 +0000 (05:49 -0700)
committerEugene Zhulenev <ezhulenev@google.com>
Thu, 7 Oct 2021 13:21:57 +0000 (06:21 -0700)
Revived https://reviews.llvm.org/D102435

Add alignment attribute to `memref.global` and propagate it to llvm global in memref->llvm lowering

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D111309

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Dialect/MemRef/invalid.mlir

index f5b8486..5bce526 100644 (file)
@@ -839,6 +839,9 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
     // Private variable with an initial value.
     memref.global "private" @x : memref<2xf32> = dense<0.0,2.0>
 
+    // Private variable with an initial value and an alignment (power of 2).
+    memref.global "private" @x : memref<2xf32> = dense<0.0,2.0> {alignment = 64}
+
     // Declaration of an external variable.
     memref.global "private" @y : memref<4xi32>
 
@@ -855,7 +858,8 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
       OptionalAttr<StrAttr>:$sym_visibility,
       MemRefTypeAttr:$type,
       OptionalAttr<AnyAttr>:$initial_value,
-      UnitAttr:$constant
+      UnitAttr:$constant,
+      OptionalAttr<I64Attr>:$alignment
   );
 
   let assemblyFormat = [{
index e43be68..74462ca 100644 (file)
@@ -451,9 +451,11 @@ struct GlobalMemrefOpLowering
         initialValue = elementsAttr.getValue({});
     }
 
+    uint64_t alignment = global.alignment().getValueOr(0);
+
     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
         global, arrayTy, global.constant(), linkage, global.sym_name(),
-        initialValue, /*alignment=*/0, type.getMemorySpaceAsInt());
+        initialValue, alignment, type.getMemorySpaceAsInt());
     if (!global.isExternal() && global.isUninitialized()) {
       Block *blk = new Block();
       newGlobal.getInitializerRegion().push_back(blk);
index 2908e6b..32bd0a2 100644 (file)
@@ -1176,6 +1176,14 @@ static LogicalResult verify(GlobalOp op) {
     }
   }
 
+  if (Optional<uint64_t> alignAttr = op.alignment()) {
+    uint64_t alignment = alignAttr.getValue();
+
+    if (!llvm::isPowerOf2_64(alignment))
+      return op->emitError() << "alignment attribute value " << alignment
+                             << " is not a power of 2";
+  }
+
   // TODO: verify visibility for declarations.
   return success();
 }
index 035251f..39dbd8e 100644 (file)
@@ -48,7 +48,8 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
       /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
       /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
-      /*constant=*/true);
+      /*constant=*/true,
+      /*alignment=*/IntegerAttr());
   symbolTable.insert(global);
   // The symbol table inserts at the end of the module, but globals are a bit
   // nicer if they are at the beginning.
index 8e5f3dd..6068645 100644 (file)
@@ -701,6 +701,10 @@ func @get_gv3_memref() {
   return
 }
 
+// Test scalar memref with an alignment.
+// CHECK: llvm.mlir.global private @gv4(1.000000e+00 : f32) {alignment = 64 : i64} : f32
+memref.global "private" @gv4 : memref<f32> = dense<1.0> {alignment = 64}
+
 // -----
 
 func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
index 8d0a20b..3a05eb1 100644 (file)
@@ -345,6 +345,11 @@ func @mismatched_types() {
 
 // -----
 
+// expected-error @+1 {{alignment attribute value 63 is not a power of 2}}
+memref.global "private" @gv : memref<4xf32> = dense<1.0> { alignment = 63 }
+
+// -----
+
 func @copy_different_shape(%arg0: memref<2xf32>, %arg1: memref<3xf32>) {
   // expected-error @+1 {{op requires the same shape for all operands}}
   memref.copy %arg0, %arg1 : memref<2xf32> to memref<3xf32>