// Compute total byte size.
auto dstByteSize =
rewriter.create<LLVM::MulOp>(loc, dstNumElements, sizeInBytes);
+ // Since the src and dst memrefs are guaranteed by the verifier to
+ // have the same element type, it is safe here to reuse the element
+ // type size computed from the dst memref.
+ auto srcByteSize =
+ rewriter.create<LLVM::MulOp>(loc, srcNumElements, sizeInBytes);
// Allocate a new buffer.
auto [dstRawPtr, dstAlignedPtr] =
allocateBuffer(rewriter, loc, dstByteSize, op);
return rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);
};
rewriter.create<LLVM::MemcpyOp>(loc, toVoidPtr(dstAlignedPtr),
- toVoidPtr(srcAlignedPtr), dstByteSize,
+ toVoidPtr(srcAlignedPtr), srcByteSize,
isVolatile);
// Deallocate the old buffer.
LLVM::LLVMFuncOp freeFunc =
// CHECK-SAME: %[[arg1:.*]]: index) -> memref<?xf32> {
func.func @realloc_dynamic(%in: memref<?xf32>, %d: index) -> memref<?xf32>{
// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]]
-// CHECK: %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0]
+// CHECK: %[[src_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0]
// CHECK: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64
-// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64
+// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] : i64
// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]]
// CHECK: ^bb1:
// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr<f32>
// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK: %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]]
// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]])
// CHECK: %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<f32>
// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1
// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr<f32> to !llvm.ptr<i8>
// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]])
// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK: %[[src_size:.*]] = llvm.mul %[[drc_dim]], %[[dst_es]]
// CHECK: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
// CHECK: %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]]
// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]])
// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1
// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]])
// ALIGNED-ALLOC: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
// ALIGNED-ALLOC: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
// ALIGNED-ALLOC: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// ALIGNED-ALLOC: %[[src_size:.*]] = llvm.mul %[[drc_dim]], %[[dst_es]]
// ALIGNED-ALLOC-DAG: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
// ALIGNED-ALLOC-DAG: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64
// ALIGNED-ALLOC: %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]]
// ALIGNED-ALLOC: %[[volatile:.*]] = llvm.mlir.constant(false) : i1
// ALIGNED-ALLOC-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
// ALIGNED-ALLOC-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// ALIGNED-ALLOC: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// ALIGNED-ALLOC: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
// ALIGNED-ALLOC: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
// ALIGNED-ALLOC: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
// ALIGNED-ALLOC: llvm.call @free(%[[old_buffer_unaligned_void]])
// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<i32> to i64
// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK: %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]]
// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]])
// CHECK: %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<i32>
// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1
// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr<i32> to !llvm.ptr<i8>
// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<i32> to !llvm.ptr<i8>
-// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<i32> to !llvm.ptr<i8>
// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]])
// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK: %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]]
// CHECK: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
// CHECK: %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]]
// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]])
// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1
// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]], %[[volatile]])
// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]])