From a50c269c7372f5f0373fe3876ed8f8acf0e2f12d Mon Sep 17 00:00:00 2001
From: Nikita Popov
Date: Mon, 7 Nov 2022 17:02:19 +0100
Subject: [PATCH] [InstCombine] Handle load smaller than one byte in memset forward

APInt::getSplat() requires that the new size is >= the original
one. If we're loading less than 8 bits, truncate instead.

Fixes https://github.com/llvm/llvm-project/issues/58845.
---
 llvm/lib/Analysis/Loads.cpp                          | 12 ++++++++----
 .../Transforms/InstCombine/load-store-forward.ll     | 20 ++++++++++++++++++++
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 93faefa..bc16c00 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -532,13 +532,17 @@ static Value *getAvailableLoadStore(Instruction *Inst, const Value *Ptr,
     if (IsLoadCSE)
       *IsLoadCSE = false;
 
+    TypeSize LoadTypeSize = DL.getTypeSizeInBits(AccessTy);
+    if (LoadTypeSize.isScalable())
+      return nullptr;
+
     // Make sure the read bytes are contained in the memset.
-    TypeSize LoadSize = DL.getTypeSizeInBits(AccessTy);
-    if (LoadSize.isScalable() ||
-        (Len->getValue() * 8).ult(LoadSize.getFixedSize()))
+    uint64_t LoadSize = LoadTypeSize.getFixedSize();
+    if ((Len->getValue() * 8).ult(LoadSize))
       return nullptr;
 
-    APInt Splat = APInt::getSplat(LoadSize.getFixedSize(), Val->getValue());
+    APInt Splat = LoadSize >= 8 ? APInt::getSplat(LoadSize, Val->getValue())
+                                : Val->getValue().trunc(LoadSize);
     ConstantInt *SplatC = ConstantInt::get(MSI->getContext(), Splat);
     if (CastInst::isBitOrNoopPointerCastable(SplatC->getType(), AccessTy, DL))
       return SplatC;
diff --git a/llvm/test/Transforms/InstCombine/load-store-forward.ll b/llvm/test/Transforms/InstCombine/load-store-forward.ll
index 5a847cd..6be5f6e 100644
--- a/llvm/test/Transforms/InstCombine/load-store-forward.ll
+++ b/llvm/test/Transforms/InstCombine/load-store-forward.ll
@@ -284,6 +284,16 @@ define i27 @load_after_memset_0_non_byte_sized(ptr %a) {
   ret i27 %v
 }
 
+define i1 @load_after_memset_0_i1(ptr %a) {
+; CHECK-LABEL: @load_after_memset_0_i1(
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
+; CHECK-NEXT:    ret i1 false
+;
+  call void @llvm.memset.p0.i64(ptr %a, i8 0, i64 16, i1 false)
+  %v = load i1, ptr %a
+  ret i1 %v
+}
+
 define <4 x i8> @load_after_memset_0_vec(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_0_vec(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
@@ -324,6 +334,16 @@ define i27 @load_after_memset_1_non_byte_sized(ptr %a) {
   ret i27 %v
 }
 
+define i1 @load_after_memset_1_i1(ptr %a) {
+; CHECK-LABEL: @load_after_memset_1_i1(
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 1, i64 16, i1 false)
+; CHECK-NEXT:    ret i1 true
+;
+  call void @llvm.memset.p0.i64(ptr %a, i8 1, i64 16, i1 false)
+  %v = load i1, ptr %a
+  ret i1 %v
+}
+
 define <4 x i8> @load_after_memset_1_vec(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_1_vec(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 1, i64 16, i1 false)
-- 
2.7.4