From 8b587113b746f31b63fd6473083df78cef30a72e Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Fri, 23 Sep 2022 01:48:54 +0000 Subject: [PATCH] [mlir][memref] fix overflow in realloc Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D134511 --- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 7 ++++++- .../Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir | 13 ++++++++----- .../Conversion/MemRefToLLVM/convert-static-memref-ops.mlir | 6 ++++-- .../Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir | 3 --- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 241f62e..8fe631b 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -191,6 +191,11 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering { // Compute total byte size. auto dstByteSize = rewriter.create(loc, dstNumElements, sizeInBytes); + // Since the src and dst memref are guaranteed to have the same + // element type by the verifier, it is safe here to reuse the + // type size computed from dst memref. + auto srcByteSize = + rewriter.create(loc, srcNumElements, sizeInBytes); // Allocate a new buffer. auto [dstRawPtr, dstAlignedPtr] = allocateBuffer(rewriter, loc, dstByteSize, op); @@ -202,7 +207,7 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering { return rewriter.create(loc, getVoidPtrType(), ptr); }; rewriter.create(loc, toVoidPtr(dstAlignedPtr), - toVoidPtr(srcAlignedPtr), dstByteSize, + toVoidPtr(srcAlignedPtr), srcByteSize, isVolatile); // Deallocate the old buffer. 
LLVM::LLVMFuncOp freeFunc = diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir index 3cd8fb3..821f298 100644 --- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir @@ -633,22 +633,23 @@ func.func @ranked_unranked() { // CHECK-SAME: %[[arg1:.*]]: index) -> memref { func.func @realloc_dynamic(%in: memref, %d: index) -> memref{ // CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] -// CHECK: %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0] +// CHECK: %[[src_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0] // CHECK: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64 -// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64 +// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] : i64 // CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] // CHECK: ^bb1: // CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] // CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 // CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] +// CHECK: %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]] // CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]]) // CHECK: %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr to !llvm.ptr // CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] // CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1 // CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr to !llvm.ptr // CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr -// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]]) +// CHECK: 
"llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]]) // CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] // CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr // CHECK: llvm.call @free(%[[old_buffer_unaligned_void]]) @@ -683,6 +684,7 @@ func.func @realloc_dynamic_alignment(%in: memref, %d: index) -> memref to i64 // CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] +// CHECK: %[[src_size:.*]] = llvm.mul %[[drc_dim]], %[[dst_es]] // CHECK: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64 // CHECK: %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]] // CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]]) @@ -698,7 +700,7 @@ func.func @realloc_dynamic_alignment(%in: memref, %d: index) -> memref to !llvm.ptr // CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr -// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]]) +// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]]) // CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] // CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr // CHECK: llvm.call @free(%[[old_buffer_unaligned_void]]) @@ -720,6 +722,7 @@ func.func @realloc_dynamic_alignment(%in: memref, %d: index) -> memref to i64 // ALIGNED-ALLOC: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] +// ALIGNED-ALLOC: %[[src_size:.*]] = llvm.mul %[[drc_dim]], %[[dst_es]] // ALIGNED-ALLOC-DAG: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64 // ALIGNED-ALLOC-DAG: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64 // ALIGNED-ALLOC: %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]] @@ -732,7 +735,7 @@ func.func @realloc_dynamic_alignment(%in: memref, %d: index) 
-> memref to !llvm.ptr // ALIGNED-ALLOC-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr -// ALIGNED-ALLOC: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]]) +// ALIGNED-ALLOC: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]]) // ALIGNED-ALLOC: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] // ALIGNED-ALLOC: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr // ALIGNED-ALLOC: llvm.call @free(%[[old_buffer_unaligned_void]]) diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir index cabc84f..0239559 100644 --- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir @@ -354,13 +354,14 @@ func.func @realloc_static(%in: memref<2xi32>) -> memref<4xi32>{ // CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] // CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 // CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] +// CHECK: %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]] // CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]]) // CHECK: %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr to !llvm.ptr // CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] // CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1 // CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr to !llvm.ptr // CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr -// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]]) +// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]]) // 
CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] // CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr // CHECK: llvm.call @free(%[[old_buffer_unaligned_void]]) @@ -391,6 +392,7 @@ func.func @realloc_static_alignment(%in: memref<2xf32>) -> memref<4xf32>{ // CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] // CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 // CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] +// CHECK: %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]] // CHECK: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64 // CHECK: %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]] // CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]]) @@ -406,7 +408,7 @@ func.func @realloc_static_alignment(%in: memref<2xf32>) -> memref<4xf32>{ // CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1 // CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr to !llvm.ptr // CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr -// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]]) +// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]]) // CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] // CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr // CHECK: llvm.call @free(%[[old_buffer_unaligned_void]]) diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir index 035c90c..ff57bfe 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir @@ 
-1,6 +1,3 @@ -// FIXME: re-enable when sanitizer issue is resolved -// UNSUPPORTED: asan -// // RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \ // RUN: mlir-cpu-runner \ // RUN: -e entry -entry-point-result=void \ -- 2.7.4