[mlir][memref] fix overflow in realloc
authorPeiming Liu <peiming@google.com>
Fri, 23 Sep 2022 01:48:54 +0000 (01:48 +0000)
committerPeiming Liu <peiming@google.com>
Fri, 23 Sep 2022 03:07:23 +0000 (03:07 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134511

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir

index 241f62e..8fe631b 100644 (file)
@@ -191,6 +191,11 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
     // Compute total byte size.
     auto dstByteSize =
         rewriter.create<LLVM::MulOp>(loc, dstNumElements, sizeInBytes);
+    // Since the src and dst memref are guarantee to have the same
+    // element type by the verifier, it is safe here to reuse the
+    // type size computed from dst memref.
+    auto srcByteSize =
+        rewriter.create<LLVM::MulOp>(loc, srcNumElements, sizeInBytes);
     // Allocate a new buffer.
     auto [dstRawPtr, dstAlignedPtr] =
         allocateBuffer(rewriter, loc, dstByteSize, op);
@@ -202,7 +207,7 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
       return rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);
     };
     rewriter.create<LLVM::MemcpyOp>(loc, toVoidPtr(dstAlignedPtr),
-                                    toVoidPtr(srcAlignedPtr), dstByteSize,
+                                    toVoidPtr(srcAlignedPtr), srcByteSize,
                                     isVolatile);
     // Deallocate the old buffer.
     LLVM::LLVMFuncOp freeFunc =
index 3cd8fb3..821f298 100644 (file)
@@ -633,22 +633,23 @@ func.func @ranked_unranked() {
 // CHECK-SAME:      %[[arg1:.*]]: index) -> memref<?xf32> {
 func.func @realloc_dynamic(%in: memref<?xf32>, %d: index) -> memref<?xf32>{
 // CHECK:           %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]]
-// CHECK:           %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0]
+// CHECK:           %[[src_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0]
 // CHECK:           %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64
-// CHECK:           %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64
+// CHECK:           %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] : i64
 // CHECK:           llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]]
 // CHECK:           ^bb1:
 // CHECK:           %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr<f32>
 // CHECK:           %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
 // CHECK:           %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
 // CHECK:           %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK:           %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]]
 // CHECK:           %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]])
 // CHECK:           %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<f32>
 // CHECK:           %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
 // CHECK:           %[[volatile:.*]] = llvm.mlir.constant(false) : i1
 // CHECK-DAG:       %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK-DAG:       %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
 // CHECK:           %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
 // CHECK:           %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK:           llvm.call @free(%[[old_buffer_unaligned_void]])
@@ -683,6 +684,7 @@ func.func @realloc_dynamic_alignment(%in: memref<?xf32>, %d: index) -> memref<?x
 // CHECK:           %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
 // CHECK:           %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
 // CHECK:           %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK:           %[[src_size:.*]] = llvm.mul %[[drc_dim]], %[[dst_es]]
 // CHECK:           %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK:           %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]]
 // CHECK:           %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]])
@@ -698,7 +700,7 @@ func.func @realloc_dynamic_alignment(%in: memref<?xf32>, %d: index) -> memref<?x
 // CHECK:           %[[volatile:.*]] = llvm.mlir.constant(false) : i1
 // CHECK-DAG:       %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK-DAG:       %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
 // CHECK:           %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
 // CHECK:           %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK:           llvm.call @free(%[[old_buffer_unaligned_void]])
@@ -720,6 +722,7 @@ func.func @realloc_dynamic_alignment(%in: memref<?xf32>, %d: index) -> memref<?x
 // ALIGNED-ALLOC:           %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
 // ALIGNED-ALLOC:           %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
 // ALIGNED-ALLOC:           %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// ALIGNED-ALLOC:           %[[src_size:.*]] = llvm.mul %[[drc_dim]], %[[dst_es]]
 // ALIGNED-ALLOC-DAG:       %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
 // ALIGNED-ALLOC-DAG:       %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64
 // ALIGNED-ALLOC:           %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]]
@@ -732,7 +735,7 @@ func.func @realloc_dynamic_alignment(%in: memref<?xf32>, %d: index) -> memref<?x
 // ALIGNED-ALLOC:           %[[volatile:.*]] = llvm.mlir.constant(false) : i1
 // ALIGNED-ALLOC-DAG:       %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // ALIGNED-ALLOC-DAG:       %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// ALIGNED-ALLOC:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// ALIGNED-ALLOC:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
 // ALIGNED-ALLOC:           %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
 // ALIGNED-ALLOC:           %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // ALIGNED-ALLOC:           llvm.call @free(%[[old_buffer_unaligned_void]])
index cabc84f..0239559 100644 (file)
@@ -354,13 +354,14 @@ func.func @realloc_static(%in: memref<2xi32>) -> memref<4xi32>{
 // CHECK:           %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
 // CHECK:           %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<i32> to i64
 // CHECK:           %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK:           %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]]
 // CHECK:           %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]])
 // CHECK:           %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<i32>
 // CHECK:           %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
 // CHECK:           %[[volatile:.*]] = llvm.mlir.constant(false) : i1
 // CHECK-DAG:       %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr<i32> to !llvm.ptr<i8>
 // CHECK-DAG:       %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<i32> to !llvm.ptr<i8>
-// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
 // CHECK:           %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
 // CHECK:           %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<i32> to !llvm.ptr<i8>
 // CHECK:           llvm.call @free(%[[old_buffer_unaligned_void]])
@@ -391,6 +392,7 @@ func.func @realloc_static_alignment(%in: memref<2xf32>) -> memref<4xf32>{
 // CHECK:           %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
 // CHECK:           %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
 // CHECK:           %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK:           %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]]
 // CHECK:           %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK:           %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]]
 // CHECK:           %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]])
@@ -406,7 +408,7 @@ func.func @realloc_static_alignment(%in: memref<2xf32>) -> memref<4xf32>{
 // CHECK:           %[[volatile:.*]] = llvm.mlir.constant(false) : i1
 // CHECK-DAG:       %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK-DAG:       %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK:           "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
 // CHECK:           %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
 // CHECK:           %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
 // CHECK:           llvm.call @free(%[[old_buffer_unaligned_void]])
index 035c90c..ff57bfe 100644 (file)
@@ -1,6 +1,3 @@
-// FIXME: re-enable when sanitizer issue is resolved
-// UNSUPPORTED: asan
-//
 // RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
 // RUN: mlir-cpu-runner \
 // RUN:  -e entry -entry-point-result=void  \