[InstCombine] Handle load smaller than one byte in memset forward
authorNikita Popov <npopov@redhat.com>
Mon, 7 Nov 2022 16:02:19 +0000 (17:02 +0100)
committerNikita Popov <npopov@redhat.com>
Mon, 7 Nov 2022 16:04:27 +0000 (17:04 +0100)
APInt::getSplat() requires that the new size is >= the original
one. If we're loading less than 8 bits, truncate instead.

Fixes https://github.com/llvm/llvm-project/issues/58845.

llvm/lib/Analysis/Loads.cpp
llvm/test/Transforms/InstCombine/load-store-forward.ll

index 93faefa..bc16c00 100644 (file)
@@ -532,13 +532,17 @@ static Value *getAvailableLoadStore(Instruction *Inst, const Value *Ptr,
     if (IsLoadCSE)
       *IsLoadCSE = false;
 
+    TypeSize LoadTypeSize = DL.getTypeSizeInBits(AccessTy);
+    if (LoadTypeSize.isScalable())
+      return nullptr;
+
     // Make sure the read bytes are contained in the memset.
-    TypeSize LoadSize = DL.getTypeSizeInBits(AccessTy);
-    if (LoadSize.isScalable() ||
-        (Len->getValue() * 8).ult(LoadSize.getFixedSize()))
+    uint64_t LoadSize = LoadTypeSize.getFixedSize();
+    if ((Len->getValue() * 8).ult(LoadSize))
       return nullptr;
 
-    APInt Splat = APInt::getSplat(LoadSize.getFixedSize(), Val->getValue());
+    APInt Splat = LoadSize >= 8 ? APInt::getSplat(LoadSize, Val->getValue())
+                                : Val->getValue().trunc(LoadSize);
     ConstantInt *SplatC = ConstantInt::get(MSI->getContext(), Splat);
     if (CastInst::isBitOrNoopPointerCastable(SplatC->getType(), AccessTy, DL))
       return SplatC;
index 5a847cd..6be5f6e 100644 (file)
@@ -284,6 +284,16 @@ define i27 @load_after_memset_0_non_byte_sized(ptr %a) {
   ret i27 %v
 }
 
+define i1 @load_after_memset_0_i1(ptr %a) {
+; CHECK-LABEL: @load_after_memset_0_i1(
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
+; CHECK-NEXT:    ret i1 false
+;
+  call void @llvm.memset.p0.i64(ptr %a, i8 0, i64 16, i1 false)
+  %v = load i1, ptr %a
+  ret i1 %v
+}
+
 define <4 x i8> @load_after_memset_0_vec(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_0_vec(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
@@ -324,6 +334,16 @@ define i27 @load_after_memset_1_non_byte_sized(ptr %a) {
   ret i27 %v
 }
 
+define i1 @load_after_memset_1_i1(ptr %a) {
+; CHECK-LABEL: @load_after_memset_1_i1(
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 1, i64 16, i1 false)
+; CHECK-NEXT:    ret i1 true
+;
+  call void @llvm.memset.p0.i64(ptr %a, i8 1, i64 16, i1 false)
+  %v = load i1, ptr %a
+  ret i1 %v
+}
+
 define <4 x i8> @load_after_memset_1_vec(ptr %a) {
 ; CHECK-LABEL: @load_after_memset_1_vec(
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 1, i64 16, i1 false)