From: Fabian Mora
Date: Mon, 22 May 2023 16:21:28 +0000 (+0000)
Subject: [mlir][memref] Fix num elements in lowering of memref.alloca op to LLVM
X-Git-Tag: upstream/17.0.6~7670
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=041f1abee11c0d5b790776bb26287bdbc7fe058c;p=platform%2Fupstream%2Fllvm.git

[mlir][memref] Fix num elements in lowering of memref.alloca op to LLVM

Fixes a mistake in the lowering of memref.alloca to llvm.alloca: llvm.alloca
takes the number of elements to allocate on the stack, not the size in bytes.

Reference:
LLVM IR: https://llvm.org/docs/LangRef.html#alloca-instruction
LLVM MLIR: https://mlir.llvm.org/docs/Dialects/LLVM/#llvmalloca-mlirllvmallocaop

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D150705
---

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 1a362f6..71bb424 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -82,12 +82,12 @@ protected:
   /// Returns the type of a pointer to an element of the memref.
   Type getElementPtrType(MemRefType type) const;
 
-  /// Computes sizes, strides and buffer size in bytes of `memRefType` with
-  /// identity layout. Emits constant ops for the static sizes of `memRefType`,
-  /// and uses `dynamicSizes` for the others. Emits instructions to compute
-  /// strides and buffer size from these sizes.
+  /// Computes sizes, strides and buffer size of `memRefType` with identity
+  /// layout. Emits constant ops for the static sizes of `memRefType`, and uses
+  /// `dynamicSizes` for the others. Emits instructions to compute strides and
+  /// buffer size from these sizes.
   ///
-  /// For example, memref<4x?xf32> emits:
+  /// For example, memref<4x?xf32> with `sizeInBytes = true` emits:
   /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
   /// `sizes[1]` = `dynamicSizes[0]`
   /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
@@ -97,19 +97,27 @@ protected:
   /// %gep = llvm.getelementptr %nullptr[%size]
   ///            : (!llvm.ptr, i64) -> !llvm.ptr
   /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr to i64
+  ///
+  /// If `sizeInBytes = false`, memref<4x?xf32> emits:
+  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
+  /// `sizes[1]` = `dynamicSizes[0]`
+  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
+  /// `strides[0]` = `sizes[0]`
+  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
   void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
                                 ValueRange dynamicSizes,
                                 ConversionPatternRewriter &rewriter,
                                 SmallVectorImpl<Value> &sizes,
-                                SmallVectorImpl<Value> &strides,
-                                Value &sizeBytes) const;
+                                SmallVectorImpl<Value> &strides, Value &size,
+                                bool sizeInBytes = true) const;
 
   /// Computes the size of type in bytes.
   Value getSizeInBytes(Location loc, Type type,
                        ConversionPatternRewriter &rewriter) const;
 
-  /// Computes total number of elements for the given shape.
-  Value getNumElements(Location loc, ArrayRef<Value> shape,
+  /// Computes total number of elements for the given MemRef and dynamicSizes.
+  Value getNumElements(Location loc, MemRefType memRefType,
+                       ValueRange dynamicSizes,
                        ConversionPatternRewriter &rewriter) const;
 
   /// Creates and populates a canonical memref descriptor struct.
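The null-GEP-plus-ptrtoint sequence in the documentation above is the usual
LLVM-dialect idiom for computing a type's size in bytes without consulting a
data layout. As a rough sketch (SSA names are illustrative, not the ones the
converter assigns; %dyn stands for dynamicSizes[0]; stride computation
omitted), memref<4x?xf32> lowers along these lines:

    %c4    = llvm.mlir.constant(4 : index) : i64
    %total = llvm.mul %c4, %dyn : i64
    // With `sizeInBytes = false` the lowering stops here and returns %total.
    // With `sizeInBytes = true` it appends the null-GEP/ptrtoint size trick:
    %null  = llvm.mlir.null : !llvm.ptr
    %gep   = llvm.getelementptr %null[%total] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %bytes = llvm.ptrtoint %gep : !llvm.ptr to i64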
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
index a063623..770f319 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
@@ -20,8 +20,10 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
   using ConvertToLLVMPattern::getVoidPtrType;
 
   explicit AllocationOpLLVMLowering(StringRef opName,
-                                    LLVMTypeConverter &converter)
-      : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
+                                    LLVMTypeConverter &converter,
+                                    PatternBenefit benefit = 1)
+      : ConvertToLLVMPattern(opName, &converter.getContext(), converter,
+                             benefit) {}
 
 protected:
   /// Computes the aligned value for 'input' as follows:
@@ -103,15 +105,20 @@ private:
 /// Lowering for AllocOp and AllocaOp.
 struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
   explicit AllocLikeOpLLVMLowering(StringRef opName,
-                                   LLVMTypeConverter &converter)
-      : AllocationOpLLVMLowering(opName, converter) {}
+                                   LLVMTypeConverter &converter,
+                                   PatternBenefit benefit = 1)
+      : AllocationOpLLVMLowering(opName, converter, benefit) {}
 
 protected:
   /// Allocates the underlying buffer. Returns the allocated pointer and the
   /// aligned pointer.
   virtual std::tuple<Value, Value>
-  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
-                 Value sizeBytes, Operation *op) const = 0;
+  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
+                 Operation *op) const = 0;
+
+  /// Sets the flag 'requiresNumElements', specifying that the op requires the
+  /// number of elements instead of the size in bytes.
+  void setRequiresNumElements();
 
 private:
   // An `alloc` is converted into a definition of a memref descriptor value and
@@ -133,6 +140,10 @@ private:
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
+
+  // Flag specifying that the op requires the number of elements instead of
+  // the size in bytes.
+  bool requiresNumElements = false;
 };
 
 } // namespace mlir

diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index c55a62e..ef099e5 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -121,7 +121,7 @@ Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
     Location loc, MemRefType memRefType, ValueRange dynamicSizes,
     ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
-    SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
+    SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
   assert(isConvertibleAndHasIdentityMaps(memRefType) &&
          "layout maps must have been normalized away");
   assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
@@ -143,14 +143,14 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
 
   for (auto i = memRefType.getRank(); i-- > 0;) {
     strides[i] = runningStride;
 
-    int64_t size = memRefType.getShape()[i];
-    if (size == 0)
+    int64_t staticSize = memRefType.getShape()[i];
+    if (staticSize == 0)
       continue;
     bool useSizeAsStride = stride == 1;
-    if (size == ShapedType::kDynamic)
+    if (staticSize == ShapedType::kDynamic)
       stride = ShapedType::kDynamic;
     if (stride != ShapedType::kDynamic)
-      stride *= size;
+      stride *= staticSize;
 
     if (useSizeAsStride)
       runningStride = sizes[i];
@@ -160,14 +160,17 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
     else
       runningStride = createIndexConstant(rewriter, loc, stride);
   }
-
-  // Buffer size in bytes.
-  Type elementType = typeConverter->convertType(memRefType.getElementType());
-  Type elementPtrType = getTypeConverter()->getPointerType(elementType);
-  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
-  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, elementType,
-                                              nullPtr, runningStride);
-  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+  if (sizeInBytes) {
+    // Buffer size in bytes.
+    Type elementType = typeConverter->convertType(memRefType.getElementType());
+    Type elementPtrType = getTypeConverter()->getPointerType(elementType);
+    Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
+    Value gepPtr = rewriter.create<LLVM::GEPOp>(
+        loc, elementPtrType, elementType, nullPtr, runningStride);
+    size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+  } else {
+    size = runningStride;
+  }
 }
 
 Value ConvertToLLVMPattern::getSizeInBytes(
@@ -186,13 +189,30 @@ Value ConvertToLLVMPattern::getSizeInBytes(
 }
 
 Value ConvertToLLVMPattern::getNumElements(
-    Location loc, ArrayRef<Value> shape,
+    Location loc, MemRefType memRefType, ValueRange dynamicSizes,
     ConversionPatternRewriter &rewriter) const {
+  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
+             static_cast<ssize_t>(dynamicSizes.size()) &&
+         "dynamicSizes size doesn't match dynamic sizes count in memref shape");
+
+  Value numElements = memRefType.getRank() == 0
+                          ? createIndexConstant(rewriter, loc, 1)
+                          : nullptr;
+  unsigned dynamicIndex = 0;
+
   // Compute the total number of memref elements.
-  Value numElements =
-      shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
-  for (unsigned i = 1, e = shape.size(); i < e; ++i)
-    numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
+  for (int64_t staticSize : memRefType.getShape()) {
+    if (numElements) {
+      Value size = staticSize == ShapedType::kDynamic
+                       ? dynamicSizes[dynamicIndex++]
+                       : createIndexConstant(rewriter, loc, staticSize);
+      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
+    } else {
+      numElements = staticSize == ShapedType::kDynamic
+                        ? dynamicSizes[dynamicIndex++]
+                        : createIndexConstant(rewriter, loc, staticSize);
+    }
+  }
   return numElements;
 }

diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index 2fa4315..b762758 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -156,6 +156,10 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
                                  elementPtrType, *getTypeConverter());
 }
 
+void AllocLikeOpLLVMLowering::setRequiresNumElements() {
+  requiresNumElements = true;
+}
+
 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
@@ -169,13 +173,14 @@ LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
   // zero-dimensional memref, assume a scalar (size 1).
   SmallVector<Value, 4> sizes;
   SmallVector<Value, 4> strides;
-  Value sizeBytes;
+  Value size;
+
   this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
-                                 strides, sizeBytes);
+                                 strides, size, !requiresNumElements);
 
   // Allocate the underlying buffer.
   auto [allocatedPtr, alignedPtr] =
-      this->allocateBuffer(rewriter, loc, sizeBytes, op);
+      this->allocateBuffer(rewriter, loc, size, op);
 
   // Create the MemRef descriptor.
   auto memRefDescriptor = this->createMemRefDescriptor(
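The distinction carried by `requiresNumElements` matters because llvm.alloca,
like LLVM IR's alloca, interprets its operand as a number of elements of the
allocated type, not as a byte count. A minimal sketch of that semantics
(illustrative names; %n is an element count):

    // Reserves %n f32 slots, i.e. %n * 4 bytes of stack.
    %buf = llvm.alloca %n x f32 : (i64) -> !llvm.ptr

Before this patch the alloca lowering passed the byte count as that operand,
so each memref.alloca reserved sizeof(element type) times more stack than
requested; the lowering below opts in to receiving the element count instead.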
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 47c2cdb..24ea1a6 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -85,13 +85,15 @@ private:
 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
   AllocaOpLowering(LLVMTypeConverter &converter)
       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
-                                converter) {}
+                                converter) {
+    setRequiresNumElements();
+  }
 
   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
   /// is set to null for stack allocations. `accessAlignment` is set if
   /// alignment is needed post allocation (for eg. in conjunction with malloc).
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value sizeBytes,
+                                          Location loc, Value size,
                                           Operation *op) const override {
 
     // With alloca, one gets a pointer to the element type right away.
@@ -104,9 +106,9 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
     auto elementPtrType =
         getTypeConverter()->getPointerType(elementType, addrSpace);
 
-    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
-        loc, elementPtrType, elementType, sizeBytes,
-        allocaOp.getAlignment().value_or(0));
+    auto allocatedElementPtr =
+        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
+                                        allocaOp.getAlignment().value_or(0));
 
     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
   }

diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index e213bee..520d629 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -91,12 +91,15 @@ gpu.module @test_module {
     %j = arith.constant 16 : index
     gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
     // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
-    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
     // CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-    // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
     // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
     // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
     // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
@@ -107,12 +110,15 @@ gpu.module @test_module {
     // CHECK: llvm.return
 
     // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
-    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
     // CHECK32: %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
     // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
     // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
     // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32

diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index b703184..0b9d806 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -86,10 +86,7 @@ func.func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
 // CHECK-DAG:  %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]]
 // CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK-NEXT:  %[[num_elems:.*]] = llvm.mul %[[N]], %[[M]] : i64
-// CHECK-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm.ptr
-// CHECK-NEXT:  %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK-NEXT:  %[[sz_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to i64
-// CHECK-NEXT:  %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x f32 : (i64) -> !llvm.ptr
+// CHECK-NEXT:  %[[allocated:.*]] = llvm.alloca %[[num_elems]] x f32 : (i64) -> !llvm.ptr
 // CHECK-NEXT:  llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:  llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:  llvm.insertvalue %[[allocated]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>

diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index 0f3f3e4..cc7b210 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -79,10 +79,7 @@ func.func @static_alloca() -> memref<32x18xf32> {
 // CHECK: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : i64
 // CHECK: %[[st2:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[num_elems:.*]] = llvm.mlir.constant(576 : index) : i64
-// CHECK: %[[null:.*]] = llvm.mlir.null : !llvm.ptr
-// CHECK: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to i64
-// CHECK: %[[allocated:.*]] = llvm.alloca %[[size_bytes]] x f32 : (i64) -> !llvm.ptr
+// CHECK: %[[allocated:.*]] = llvm.alloca %[[num_elems]] x f32 : (i64) -> !llvm.ptr
 %0 = memref.alloca() : memref<32x18xf32>

 // Test with explicitly specified alignment. llvm.alloca takes care of the
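To make the effect concrete, a simplified before/after sketch for the static
case above (memref<32x18xf32>, 576 f32 elements; value names are illustrative):

    // Before: the byte size 576 * 4 = 2304 was used as the alloca count,
    // reserving 2304 x f32 = 9216 bytes, four times the intended 2304.
    %n     = llvm.mlir.constant(576 : index) : i64
    %null  = llvm.mlir.null : !llvm.ptr
    %gep   = llvm.getelementptr %null[%n] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %bytes = llvm.ptrtoint %gep : !llvm.ptr to i64
    %buf   = llvm.alloca %bytes x f32 : (i64) -> !llvm.ptr

    // After: the element count is passed directly, reserving 576 x f32 = 2304 bytes.
    %buf   = llvm.alloca %n x f32 : (i64) -> !llvm.ptr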