[mlir] Handle different pointer sizes in unranked memref descriptors
authorKrzysztof Drewniak <Krzysztof.Drewniak@amd.com>
Mon, 9 Jan 2023 15:56:22 +0000 (15:56 +0000)
committerKrzysztof Drewniak <Krzysztof.Drewniak@amd.com>
Thu, 9 Feb 2023 19:14:58 +0000 (19:14 +0000)
The code for unranked memref descriptors assumed that
sizeof(!llvm.ptr) == lizeof(!llvm.ptr<N>) for all address spaces N.
This is not always true (ex. the AMDGPU compiler backend has
sizeof(!llvm.ptr) = 64 bits but sizeof(!llvm.ptr<5>) = 32 bits, where
address space 5 is used for stack allocations). While this is merely
an overallocation in the case where a non-0 address space has pointers
smaller than the default, the existing code could cause OOB memory
accesses when sizeof(!llvm.ptr<N>) > sizeof(!llvm.ptr).

So, add an address spaces parameter to computeSizes in order to
partially resolve this class of bugs. Note that the LLVM data layout
in the conversion passes is currently set to "" and not constructed
from the MLIR data layout or some other source, but this could change
in the future.

Depends on D142159

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D141293

mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h
mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir

index bf1bd8f..a68e087 100644 (file)
@@ -157,11 +157,11 @@ public:
                                         Type descriptorType);
 
   /// Builds IR extracting the rank from the descriptor
-  Value rank(OpBuilder &builder, Location loc);
+  Value rank(OpBuilder &builder, Location loc) const;
   /// Builds IR setting the rank in the descriptor
   void setRank(OpBuilder &builder, Location loc, Value value);
   /// Builds IR extracting ranked memref descriptor ptr
-  Value memRefDescPtr(OpBuilder &builder, Location loc);
+  Value memRefDescPtr(OpBuilder &builder, Location loc) const;
   /// Builds IR setting ranked memref descriptor ptr
   void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
 
@@ -183,10 +183,13 @@ public:
   static unsigned getNumUnpackedValues() { return 2; }
 
   /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
-  /// and appends the corresponding values into `sizes`.
+  /// and appends the corresponding values into `sizes`. `addressSpaces`
+  /// which must have the same length as `values`, is needed to handle layouts
+  /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
   static void computeSizes(OpBuilder &builder, Location loc,
                            LLVMTypeConverter &typeConverter,
                            ArrayRef<UnrankedMemRefDescriptor> values,
+                           ArrayRef<unsigned> addressSpaces,
                            SmallVectorImpl<Value> &sizes);
 
   /// TODO: The following accessors don't take alignment rules between elements
index 3523f98..1a5b97e 100644 (file)
@@ -41,7 +41,7 @@ protected:
 
 protected:
   /// Builds IR to extract a value from the struct at position pos
-  Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
+  Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const;
   /// Builds IR to set a value in the struct at position pos
   void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
 };
index 8ce9fb6..17259e4 100644 (file)
@@ -296,7 +296,7 @@ UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
   return UnrankedMemRefDescriptor(descriptor);
 }
-Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
+Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
   return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
 }
 void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
@@ -304,7 +304,7 @@ void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
   setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
 }
 Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
-                                              Location loc) {
+                                              Location loc) const {
   return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
 }
 void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
@@ -341,24 +341,24 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
 
 void UnrankedMemRefDescriptor::computeSizes(
     OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
-    ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
+    ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
+    SmallVectorImpl<Value> &sizes) {
   if (values.empty())
     return;
-
+  assert(values.size() == addressSpaces.size() &&
+         "must provide address space for each descriptor");
   // Cache the index type.
   Type indexType = typeConverter.getIndexType();
 
   // Initialize shared constants.
   Value one = createIndexAttrConstant(builder, loc, indexType, 1);
   Value two = createIndexAttrConstant(builder, loc, indexType, 2);
-  Value pointerSize = createIndexAttrConstant(
-      builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
   Value indexSize =
       createIndexAttrConstant(builder, loc, indexType,
                               ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
 
   sizes.reserve(sizes.size() + values.size());
-  for (UnrankedMemRefDescriptor desc : values) {
+  for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
     // Emit IR computing the memory necessary to store the descriptor. This
     // assumes the descriptor to be
     //   { type*, type*, index, index[rank], index[rank] }
@@ -366,6 +366,9 @@ void UnrankedMemRefDescriptor::computeSizes(
     //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
     // TODO: consider including the actual size (including eventual padding due
     // to data layout) into the unranked descriptor.
+    Value pointerSize = createIndexAttrConstant(
+        builder, loc, indexType,
+        ceilDiv(typeConverter.getPointerBitwidth(addressSpace), 8));
     Value doublePointerSize =
         builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
 
index 69a0172..d3983a3 100644 (file)
@@ -232,18 +232,27 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
          "expected as may original types as operands");
 
   // Find operands of unranked memref type and store them.
-  SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
-  for (unsigned i = 0, e = operands.size(); i < e; ++i)
-    if (origTypes[i].isa<UnrankedMemRefType>())
+  SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
+  SmallVector<unsigned> unrankedAddressSpaces;
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    if (auto memRefType = origTypes[i].dyn_cast<UnrankedMemRefType>()) {
       unrankedMemrefs.emplace_back(operands[i]);
+      FailureOr<unsigned> addressSpace =
+          getTypeConverter()->getMemRefAddressSpace(memRefType);
+      if (failed(addressSpace))
+        return failure();
+      unrankedAddressSpaces.emplace_back(*addressSpace);
+    }
+  }
 
   if (unrankedMemrefs.empty())
     return success();
 
   // Compute allocation sizes.
-  SmallVector<Value, 4> sizes;
+  SmallVector<Value> sizes;
   UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
-                                         unrankedMemrefs, sizes);
+                                         unrankedMemrefs, unrankedAddressSpaces,
+                                         sizes);
 
   // Get frequently used types.
   MLIRContext *context = builder.getContext();
index b5b192d..1cd0bd8 100644 (file)
@@ -23,7 +23,7 @@ StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) {
 }
 
 Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
-                                unsigned pos) {
+                                unsigned pos) const {
   return builder.create<LLVM::ExtractValueOp>(loc, value, pos);
 }
 
index c3ff7d8..5dccb9b 100644 (file)
@@ -1329,7 +1329,7 @@ private:
     targetDesc.setRank(rewriter, loc, resultRank);
     SmallVector<Value, 4> sizes;
     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
-                                           targetDesc, sizes);
+                                           targetDesc, addressSpace, sizes);
     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
         loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
         sizes.front());
index 509a03c..cdd8cd7 100644 (file)
@@ -742,8 +742,6 @@ convertTransferReadToLoads(vector::TransferReadOp op,
   if (failed(warpMatrixInfo))
     return failure();
 
-  Attribute memorySpace =
-      op.getSource().getType().cast<MemRefType>().getMemorySpace();
   bool isLdMatrixCompatible =
       isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
index b59c0b6..daa824d 100644 (file)
@@ -122,9 +122,9 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
   // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
   // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
   // These sizes may depend on the data layout, not matching specific values.
-  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
 
+  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
   // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
@@ -153,13 +153,12 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes
   // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[ALLOCA]], %[[DESC_1]][1]
   %0 = memref.cast %arg0: memref<4x3xf32> to memref<*xf32>
 
-
   // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
   // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
   // These sizes may depend on the data layout, not matching specific values.
-  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
 
+  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
index 6987bdc..9624b18 100644 (file)
@@ -408,8 +408,8 @@ func.func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
 // Compute size in bytes to allocate result ranked descriptor
 // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : i64
 // CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
 // CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8
index e14f70d..36206f1 100644 (file)
@@ -323,8 +323,8 @@ func.func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
 // Compute size in bytes to allocate result ranked descriptor
 // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : i64
 // CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
 // CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8