[mlir][memref] Fix num elements in lowering of memref.alloca op to LLVM
authorFabian Mora <fmora.dev@gmail.com>
Mon, 22 May 2023 16:21:28 +0000 (16:21 +0000)
committerfmorac <fmora.dev@gmail.com>
Mon, 22 May 2023 16:23:00 +0000 (16:23 +0000)
Fixes a mistake in the lowering of memref.alloca to llvm.alloca, as llvm.alloca uses the number of elements to allocate in the stack and not the size in bytes.

Reference:
LLVM IR: https://llvm.org/docs/LangRef.html#alloca-instruction
LLVM MLIR: https://mlir.llvm.org/docs/Dialects/LLVM/#llvmalloca-mlirllvmallocaop

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D150705

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir

index 1a362f6..71bb424 100644 (file)
@@ -82,12 +82,12 @@ protected:
   /// Returns the type of a pointer to an element of the memref.
   Type getElementPtrType(MemRefType type) const;
 
-  /// Computes sizes, strides and buffer size in bytes of `memRefType` with
-  /// identity layout. Emits constant ops for the static sizes of `memRefType`,
-  /// and uses `dynamicSizes` for the others. Emits instructions to compute
-  /// strides and buffer size from these sizes.
+  /// Computes sizes, strides and buffer size of `memRefType` with identity
+  /// layout. Emits constant ops for the static sizes of `memRefType`, and uses
+  /// `dynamicSizes` for the others. Emits instructions to compute strides and
+  /// buffer size from these sizes.
   ///
-  /// For example, memref<4x?xf32> emits:
+  /// For example, memref<4x?xf32> with `sizeInBytes = true` emits:
   /// `sizes[0]`   = llvm.mlir.constant(4 : index) : i64
   /// `sizes[1]`   = `dynamicSizes[0]`
   /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
@@ -97,19 +97,27 @@ protected:
   /// %gep         = llvm.getelementptr %nullptr[%size]
   ///                  : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
   /// `sizeBytes`  = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64
+  ///
+  /// If `sizeInBytes = false`, memref<4x?xf32> emits:
+  /// `sizes[0]`   = llvm.mlir.constant(4 : index) : i64
+  /// `sizes[1]`   = `dynamicSizes[0]`
+  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
+  /// `strides[0]` = `sizes[0]`
+  /// %size        = llvm.mul `sizes[0]`, `sizes[1]` : i64
   void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
                                 ValueRange dynamicSizes,
                                 ConversionPatternRewriter &rewriter,
                                 SmallVectorImpl<Value> &sizes,
-                                SmallVectorImpl<Value> &strides,
-                                Value &sizeBytes) const;
+                                SmallVectorImpl<Value> &strides, Value &size,
+                                bool sizeInBytes = true) const;
 
   /// Computes the size of type in bytes.
   Value getSizeInBytes(Location loc, Type type,
                        ConversionPatternRewriter &rewriter) const;
 
-  /// Computes total number of elements for the given shape.
-  Value getNumElements(Location loc, ArrayRef<Value> shape,
+  /// Computes total number of elements for the given MemRef and dynamicSizes.
+  Value getNumElements(Location loc, MemRefType memRefType,
+                       ValueRange dynamicSizes,
                        ConversionPatternRewriter &rewriter) const;
 
   /// Creates and populates a canonical memref descriptor struct.
index a063623..770f319 100644 (file)
@@ -20,8 +20,10 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
   using ConvertToLLVMPattern::getVoidPtrType;
 
   explicit AllocationOpLLVMLowering(StringRef opName,
-                                    LLVMTypeConverter &converter)
-      : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
+                                    LLVMTypeConverter &converter,
+                                    PatternBenefit benefit = 1)
+      : ConvertToLLVMPattern(opName, &converter.getContext(), converter,
+                             benefit) {}
 
 protected:
   /// Computes the aligned value for 'input' as follows:
@@ -103,15 +105,20 @@ private:
 /// Lowering for AllocOp and AllocaOp.
 struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
   explicit AllocLikeOpLLVMLowering(StringRef opName,
-                                   LLVMTypeConverter &converter)
-      : AllocationOpLLVMLowering(opName, converter) {}
+                                   LLVMTypeConverter &converter,
+                                   PatternBenefit benefit = 1)
+      : AllocationOpLLVMLowering(opName, converter, benefit) {}
 
 protected:
   /// Allocates the underlying buffer. Returns the allocated pointer and the
   /// aligned pointer.
   virtual std::tuple<Value, Value>
-  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
-                 Value sizeBytes, Operation *op) const = 0;
+  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
+                 Operation *op) const = 0;
+
+  /// Sets the flag 'requiresNumElements', specifying the Op requires the number
+  /// of elements instead of the size in bytes.
+  void setRequiresNumElements();
 
 private:
   // An `alloc` is converted into a definition of a memref descriptor value and
@@ -133,6 +140,10 @@ private:
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
+
+  // Flag for specifying the Op requires the number of elements instead of the
+  // size in bytes.
+  bool requiresNumElements = false;
 };
 
 } // namespace mlir
index c55a62e..ef099e5 100644 (file)
@@ -121,7 +121,7 @@ Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
     Location loc, MemRefType memRefType, ValueRange dynamicSizes,
     ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
-    SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
+    SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
   assert(isConvertibleAndHasIdentityMaps(memRefType) &&
          "layout maps must have been normalized away");
   assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
@@ -143,14 +143,14 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
   for (auto i = memRefType.getRank(); i-- > 0;) {
     strides[i] = runningStride;
 
-    int64_t size = memRefType.getShape()[i];
-    if (size == 0)
+    int64_t staticSize = memRefType.getShape()[i];
+    if (staticSize == 0)
       continue;
     bool useSizeAsStride = stride == 1;
-    if (size == ShapedType::kDynamic)
+    if (staticSize == ShapedType::kDynamic)
       stride = ShapedType::kDynamic;
     if (stride != ShapedType::kDynamic)
-      stride *= size;
+      stride *= staticSize;
 
     if (useSizeAsStride)
       runningStride = sizes[i];
@@ -160,14 +160,17 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
     else
       runningStride = createIndexConstant(rewriter, loc, stride);
   }
-
-  // Buffer size in bytes.
-  Type elementType = typeConverter->convertType(memRefType.getElementType());
-  Type elementPtrType = getTypeConverter()->getPointerType(elementType);
-  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
-  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, elementType,
-                                              nullPtr, runningStride);
-  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+  if (sizeInBytes) {
+    // Buffer size in bytes.
+    Type elementType = typeConverter->convertType(memRefType.getElementType());
+    Type elementPtrType = getTypeConverter()->getPointerType(elementType);
+    Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
+    Value gepPtr = rewriter.create<LLVM::GEPOp>(
+        loc, elementPtrType, elementType, nullPtr, runningStride);
+    size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+  } else {
+    size = runningStride;
+  }
 }
 
 Value ConvertToLLVMPattern::getSizeInBytes(
@@ -186,13 +189,30 @@ Value ConvertToLLVMPattern::getSizeInBytes(
 }
 
 Value ConvertToLLVMPattern::getNumElements(
-    Location loc, ArrayRef<Value> shape,
+    Location loc, MemRefType memRefType, ValueRange dynamicSizes,
     ConversionPatternRewriter &rewriter) const {
+  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
+             static_cast<ssize_t>(dynamicSizes.size()) &&
+         "dynamicSizes size doesn't match dynamic sizes count in memref shape");
+
+  Value numElements = memRefType.getRank() == 0
+                          ? createIndexConstant(rewriter, loc, 1)
+                          : nullptr;
+  unsigned dynamicIndex = 0;
+
   // Compute the total number of memref elements.
-  Value numElements =
-      shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
-  for (unsigned i = 1, e = shape.size(); i < e; ++i)
-    numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
+  for (int64_t staticSize : memRefType.getShape()) {
+    if (numElements) {
+      Value size = staticSize == ShapedType::kDynamic
+                       ? dynamicSizes[dynamicIndex++]
+                       : createIndexConstant(rewriter, loc, staticSize);
+      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
+    } else {
+      numElements = staticSize == ShapedType::kDynamic
+                        ? dynamicSizes[dynamicIndex++]
+                        : createIndexConstant(rewriter, loc, staticSize);
+    }
+  }
   return numElements;
 }
 
index 2fa4315..b762758 100644 (file)
@@ -156,6 +156,10 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
                              elementPtrType, *getTypeConverter());
 }
 
+void AllocLikeOpLLVMLowering::setRequiresNumElements() {
+  requiresNumElements = true;
+}
+
 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
@@ -169,13 +173,14 @@ LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
   // zero-dimensional memref, assume a scalar (size 1).
   SmallVector<Value, 4> sizes;
   SmallVector<Value, 4> strides;
-  Value sizeBytes;
+  Value size;
+
   this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
-                                 strides, sizeBytes);
+                                 strides, size, !requiresNumElements);
 
   // Allocate the underlying buffer.
   auto [allocatedPtr, alignedPtr] =
-      this->allocateBuffer(rewriter, loc, sizeBytes, op);
+      this->allocateBuffer(rewriter, loc, size, op);
 
   // Create the MemRef descriptor.
   auto memRefDescriptor = this->createMemRefDescriptor(
index 47c2cdb..24ea1a6 100644 (file)
@@ -85,13 +85,15 @@ private:
 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
   AllocaOpLowering(LLVMTypeConverter &converter)
       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
-                                converter) {}
+                                converter) {
+    setRequiresNumElements();
+  }
 
   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
   /// is set to null for stack allocations. `accessAlignment` is set if
   /// alignment is needed post allocation (for eg. in conjunction with malloc).
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value sizeBytes,
+                                          Location loc, Value size,
                                           Operation *op) const override {
 
     // With alloca, one gets a pointer to the element type right away.
@@ -104,9 +106,9 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
     auto elementPtrType =
         getTypeConverter()->getPointerType(elementType, addrSpace);
 
-    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
-        loc, elementPtrType, elementType, sizeBytes,
-        allocaOp.getAlignment().value_or(0));
+    auto allocatedElementPtr =
+        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
+                                        allocaOp.getAlignment().value_or(0));
 
     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
   }
index e213bee..520d629 100644 (file)
@@ -91,12 +91,15 @@ gpu.module @test_module {
     %j = arith.constant 16 : index
     gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
     // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
-    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK:  %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
     // CHECK:  %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
     // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
     // CHECK:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]   : i64
     // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i64
@@ -107,12 +110,15 @@ gpu.module @test_module {
     // CHECK:  llvm.return
 
     // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
-    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32:  %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
     // CHECK32:  %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32:  %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32:  %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32:  %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
     // CHECK32:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
     // CHECK32:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]   : i32
     // CHECK32:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
index b703184..0b9d806 100644 (file)
@@ -86,10 +86,7 @@ func.func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
 //   CHECK-DAG:  %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]]
 //  CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : i64
 //  CHECK-NEXT:  %[[num_elems:.*]] = llvm.mul %[[N]], %[[M]] : i64
-//  CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm.ptr
-//  CHECK-NEXT:  %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-//  CHECK-NEXT:  %[[sz_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to i64
-//  CHECK-NEXT:  %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x f32 : (i64) -> !llvm.ptr
+//  CHECK-NEXT:  %[[allocated:.*]] = llvm.alloca %[[num_elems]] x f32 : (i64) -> !llvm.ptr
 //  CHECK-NEXT:  llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 //  CHECK-NEXT:  llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 //  CHECK-NEXT:  llvm.insertvalue %[[allocated]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
index 0f3f3e4..cc7b210 100644 (file)
@@ -79,10 +79,7 @@ func.func @static_alloca() -> memref<32x18xf32> {
 // CHECK: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : i64
 // CHECK: %[[st2:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[num_elems:.*]] = llvm.mlir.constant(576 : index) : i64
-// CHECK: %[[null:.*]] = llvm.mlir.null : !llvm.ptr
-// CHECK: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to i64
-// CHECK: %[[allocated:.*]] = llvm.alloca %[[size_bytes]] x f32 : (i64) -> !llvm.ptr
+// CHECK: %[[allocated:.*]] = llvm.alloca %[[num_elems]] x f32 : (i64) -> !llvm.ptr
  %0 = memref.alloca() : memref<32x18xf32>
 
  // Test with explicitly specified alignment. llvm.alloca takes care of the