From f025053977f330152da081b7060a5d9cba0a9e22 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Sun, 27 Jun 2021 16:19:21 +0200 Subject: [PATCH] [MemCpyOpt] Handle unusual memcpy element type Apparently, it is legal to use memcpy/memset with pointer types other than i8*. Prior to 81fcdae68c5ff656c30032fd26c6a21af4c51dbb this case was silently miscompiled, as the i8 offset calculation was performed on some other type. Now it would crash due to a type mismatch. Fix this by inserting an explicit bitcast to i8*. --- llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp | 4 +++- .../MemCpyOpt/memset-memcpy-redundant-memset.ll | 23 ++++++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 6016779..b0b1a3f 100644 --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -1215,7 +1215,9 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, Value *MemsetLen = Builder.CreateSelect( Ule, ConstantInt::getNullValue(DestSize->getType()), SizeDiff); Instruction *NewMemSet = Builder.CreateMemSet( - Builder.CreateGEP(Builder.getInt8Ty(), Dest, SrcSize), + Builder.CreateGEP(Builder.getInt8Ty(), + Builder.CreatePointerCast(Dest, Builder.getInt8PtrTy()), + SrcSize), MemSet->getOperand(1), MemsetLen, MaybeAlign(Align)); if (MSSAU) { diff --git a/llvm/test/Transforms/MemCpyOpt/memset-memcpy-redundant-memset.ll b/llvm/test/Transforms/MemCpyOpt/memset-memcpy-redundant-memset.ll index 9873216..1a91d2e 100644 --- a/llvm/test/Transforms/MemCpyOpt/memset-memcpy-redundant-memset.ll +++ b/llvm/test/Transforms/MemCpyOpt/memset-memcpy-redundant-memset.ll @@ -296,8 +296,9 @@ define void @test_opaque_ptrs(ptr %src, i64 %src_size, ptr noalias %dst, i64 %ds ; CHECK-NEXT: [[TMP1:%.*]] = icmp ule i64 [[DST_SIZE:%.*]], [[SRC_SIZE:%.*]] ; CHECK-NEXT: [[TMP2:%.*]] = sub i64 [[DST_SIZE]], [[SRC_SIZE]] ; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP1]], i64 0, i64 [[TMP2]] -; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[DST:%.*]], i64 [[SRC_SIZE]] -; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr align 1 [[TMP4]], i8 [[C:%.*]], i64 [[TMP3]], i1 false) +; CHECK-NEXT: [[TMP4:%.*]] = bitcast ptr [[DST:%.*]] to i8* +; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, i8* [[TMP4]], i64 [[SRC_SIZE]] +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 1 [[TMP5]], i8 [[C:%.*]], i64 [[TMP3]], i1 false) ; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr [[DST]], ptr [[SRC:%.*]], i64 [[SRC_SIZE]], i1 false) ; CHECK-NEXT: ret void ; @@ -306,6 +307,22 @@ define void @test_opaque_ptrs(ptr %src, i64 %src_size, ptr noalias %dst, i64 %ds ret void } +define void @test_weird_element_type(i16* %src, i64 %src_size, i16* noalias %dst, i64 %dst_size, i8 %c) { +; CHECK-LABEL: @test_weird_element_type( +; CHECK-NEXT: [[TMP1:%.*]] = icmp ule i64 [[DST_SIZE:%.*]], [[SRC_SIZE:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = sub i64 [[DST_SIZE]], [[SRC_SIZE]] +; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP1]], i64 0, i64 [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = bitcast i16* [[DST:%.*]] to i8* +; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, i8* [[TMP4]], i64 [[SRC_SIZE]] +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 1 [[TMP5]], i8 [[C:%.*]], i64 [[TMP3]], i1 false) +; CHECK-NEXT: call void @llvm.memcpy.p0i16.p0i16.i64(i16* [[DST]], i16* [[SRC:%.*]], i64 [[SRC_SIZE]], i1 false) +; CHECK-NEXT: ret void +; + call void @llvm.memset.p0i16.i64(i16* %dst, i8 %c, i64 %dst_size, i1 false) + call void @llvm.memcpy.p0i16.p0i16.i64(i16* %dst, i16* %src, i64 %src_size, i1 false) + ret void +} + declare void @llvm.memset.p0i8.i64(i8* nocapture, i8, i64, i1) declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture, i8* nocapture readonly, i64, i1) declare void @llvm.memset.p0i8.i32(i8* nocapture, i8, i32, i1) @@ -314,4 +331,6 @@ declare void @llvm.memset.p0i8.i128(i8* nocapture, i8, i128, i1) declare void @llvm.memcpy.p0i8.p0i8.i128(i8* nocapture, i8* nocapture readonly, i128, i1) declare void @llvm.memset.p0.i64(ptr nocapture, i8, i64, i1) declare void @llvm.memcpy.p0.p0.i64(ptr nocapture, ptr nocapture readonly, i64, i1) +declare void @llvm.memset.p0i16.i64(i16* nocapture, i8, i64, i1) +declare void @llvm.memcpy.p0i16.p0i16.i64(i16* nocapture, i16* nocapture readonly, i64, i1) declare void @call() -- 2.7.4