[BasicAA] Model implicit trunc of GEP indices
authorNikita Popov <nikita.ppv@gmail.com>
Wed, 29 Sep 2021 20:18:54 +0000 (22:18 +0200)
committerNikita Popov <nikita.ppv@gmail.com>
Fri, 22 Oct 2021 21:47:02 +0000 (23:47 +0200)
GEP indices larger than the GEP index size are implicitly truncated
to the index size. BasicAA currently doesn't model this, resulting
in incorrect alias analysis results.

Fix this by explicitly modelling truncation in CastedValue in the
same way we do zext and sext. Additionally we need to disable a
number of optimizations for truncated values, in particular
"non-zero" and "non-equal" may no longer hold after truncation.
I believe the constant offset heuristic is also not necessarily
correct for truncated values, but wasn't able to come up with a
test for that one.

A possible followup here would be to use the new mechanism to
model explicit trunc as well (which should be much more common,
as it is the canonical form). This is straightforward, but omitted
here to separate the correctness fix from the analysis improvement.

(Side note: While I say "index size" above, BasicAA currently uses
the pointer size instead. Something for another day...)

Differential Revision: https://reviews.llvm.org/D110977

llvm/lib/Analysis/BasicAliasAnalysis.cpp
llvm/test/Analysis/BasicAA/gep-implicit-trunc-32-bit-pointers.ll

index 865db9f..25b6d9b 100644 (file)
@@ -264,43 +264,55 @@ void EarliestEscapeInfo::removeInstruction(Instruction *I) {
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Represents zext(sext(V)).
+/// Represents zext(sext(trunc(V))).
 struct CastedValue {
   const Value *V;
   unsigned ZExtBits = 0;
   unsigned SExtBits = 0;
+  unsigned TruncBits = 0;
 
   explicit CastedValue(const Value *V) : V(V) {}
-  explicit CastedValue(const Value *V, unsigned ZExtBits, unsigned SExtBits)
-      : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits) {}
+  explicit CastedValue(const Value *V, unsigned ZExtBits, unsigned SExtBits,
+                       unsigned TruncBits)
+      : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits) {}
 
   unsigned getBitWidth() const {
-    return V->getType()->getPrimitiveSizeInBits() + ZExtBits + SExtBits;
+    return V->getType()->getPrimitiveSizeInBits() - TruncBits + ZExtBits +
+           SExtBits;
   }
 
   CastedValue withValue(const Value *NewV) const {
-    return CastedValue(NewV, ZExtBits, SExtBits);
+    return CastedValue(NewV, ZExtBits, SExtBits, TruncBits);
   }
 
   /// Replace V with zext(NewV)
   CastedValue withZExtOfValue(const Value *NewV) const {
     unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
                         NewV->getType()->getPrimitiveSizeInBits();
+    if (ExtendBy <= TruncBits)
+      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+
     // zext(sext(zext(NewV))) == zext(zext(zext(NewV)))
-    return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0);
+    ExtendBy -= TruncBits;
+    return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0);
   }
 
   /// Replace V with sext(NewV)
   CastedValue withSExtOfValue(const Value *NewV) const {
     unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
                         NewV->getType()->getPrimitiveSizeInBits();
+    if (ExtendBy <= TruncBits)
+      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+
     // zext(sext(sext(NewV)))
-    return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy);
+    ExtendBy -= TruncBits;
+    return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0);
   }
 
   APInt evaluateWith(APInt N) const {
     assert(N.getBitWidth() == V->getType()->getPrimitiveSizeInBits() &&
            "Incompatible bit width");
+    if (TruncBits) N = N.trunc(N.getBitWidth() - TruncBits);
     if (SExtBits) N = N.sext(N.getBitWidth() + SExtBits);
     if (ZExtBits) N = N.zext(N.getBitWidth() + ZExtBits);
     return N;
@@ -309,6 +321,7 @@ struct CastedValue {
   KnownBits evaluateWith(KnownBits N) const {
     assert(N.getBitWidth() == V->getType()->getPrimitiveSizeInBits() &&
            "Incompatible bit width");
+    if (TruncBits) N = N.trunc(N.getBitWidth() - TruncBits);
     if (SExtBits) N = N.sext(N.getBitWidth() + SExtBits);
     if (ZExtBits) N = N.zext(N.getBitWidth() + ZExtBits);
     return N;
@@ -317,6 +330,7 @@ struct CastedValue {
   ConstantRange evaluateWith(ConstantRange N) const {
     assert(N.getBitWidth() == V->getType()->getPrimitiveSizeInBits() &&
            "Incompatible bit width");
+    if (TruncBits) N = N.truncate(N.getBitWidth() - TruncBits);
     if (SExtBits) N = N.signExtend(N.getBitWidth() + SExtBits);
     if (ZExtBits) N = N.zeroExtend(N.getBitWidth() + ZExtBits);
     return N;
@@ -325,15 +339,17 @@ struct CastedValue {
   bool canDistributeOver(bool NUW, bool NSW) const {
     // zext(x op<nuw> y) == zext(x) op<nuw> zext(y)
     // sext(x op<nsw> y) == sext(x) op<nsw> sext(y)
+    // trunc(x op y) == trunc(x) op trunc(y)
     return (!ZExtBits || NUW) && (!SExtBits || NSW);
   }
 
   bool hasSameCastsAs(const CastedValue &Other) const {
-    return ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits;
+    return ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
+           TruncBits == Other.TruncBits;
   }
 };
 
-/// Represents zext(sext(V)) * Scale + Offset.
+/// Represents zext(sext(trunc(V))) * Scale + Offset.
 struct LinearExpression {
   CastedValue Val;
   APInt Scale;
@@ -380,6 +396,11 @@ static LinearExpression GetLinearExpression(
       if (!Val.canDistributeOver(NUW, NSW))
         return Val;
 
+      // While we can distribute over trunc, we cannot preserve nowrap flags
+      // in that case.
+      if (Val.TruncBits)
+        NUW = NSW = false;
+
       LinearExpression E(Val);
       switch (BOp->getOpcode()) {
       default:
@@ -462,7 +483,7 @@ static APInt adjustToPointerSize(const APInt &Offset, unsigned PointerSize) {
 
 namespace {
 // A linear transformation of a Value; this class represents
-// ZExt(SExt(V, SExtBits), ZExtBits) * Scale.
+// ZExt(SExt(Trunc(V, TruncBits), SExtBits), ZExtBits) * Scale.
 struct VariableGEPIndex {
   CastedValue Val;
   APInt Scale;
@@ -481,6 +502,7 @@ struct VariableGEPIndex {
     OS << "(V=" << Val.V->getName()
        << ", zextbits=" << Val.ZExtBits
        << ", sextbits=" << Val.SExtBits
+       << ", truncbits=" << Val.TruncBits
        << ", scale=" << Scale << ")";
   }
 };
@@ -638,8 +660,9 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
       // sign extended to pointer size.
       unsigned Width = Index->getType()->getIntegerBitWidth();
       unsigned SExtBits = PointerSize > Width ? PointerSize - Width : 0;
+      unsigned TruncBits = PointerSize < Width ? Width - PointerSize : 0;
       LinearExpression LE = GetLinearExpression(
-          CastedValue(Index, 0, SExtBits), DL, 0, AC, DT);
+          CastedValue(Index, 0, SExtBits, TruncBits), DL, 0, AC, DT);
 
       // The GEP index scale ("Scale") scales C1*V+C2, yielding (C1*V+C2)*Scale.
       // This gives us an aggregate computation of (C1*Scale)*V + C2*Scale.
@@ -655,7 +678,7 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
       APInt ScaledOffset = LE.Offset.sextOrTrunc(MaxPointerSize)
                            .smul_ov(Scale, Overflow);
       if (Overflow) {
-        LE = LinearExpression(CastedValue(Index, 0, SExtBits));
+        LE = LinearExpression(CastedValue(Index, 0, SExtBits, TruncBits));
       } else {
         Decomposed.Offset += ScaledOffset;
         Scale *= LE.Scale.sextOrTrunc(MaxPointerSize);
@@ -1245,7 +1268,6 @@ AliasResult BasicAAResult::aliasGEP(
       if (AllNonNegative || AllNonPositive) {
         KnownBits Known = Index.Val.evaluateWith(
             computeKnownBits(Index.Val.V, DL, 0, &AC, Index.CxtI, DT));
-        // TODO: Account for implicit trunc.
         bool SignKnownZero = Known.isNonNegative();
         bool SignKnownOne = Known.isNegative();
         AllNonNegative &= (SignKnownZero && Scale.isNonNegative()) ||
@@ -1294,7 +1316,8 @@ AliasResult BasicAAResult::aliasGEP(
       if (DecompGEP1.VarIndices.size() == 1) {
         // VarIndex = Scale*V.
         const VariableGEPIndex &Var = DecompGEP1.VarIndices[0];
-        if (isKnownNonZero(Var.Val.V, DL, 0, &AC, Var.CxtI, DT)) {
+        if (Var.Val.TruncBits == 0 &&
+            isKnownNonZero(Var.Val.V, DL, 0, &AC, Var.CxtI, DT)) {
           // If V != 0 then abs(VarIndex) >= abs(Scale).
           MinAbsVarIndex = Var.Scale.abs();
         }
@@ -1310,7 +1333,7 @@ AliasResult BasicAAResult::aliasGEP(
         // inequality of values across loop iterations.
         const VariableGEPIndex &Var0 = DecompGEP1.VarIndices[0];
         const VariableGEPIndex &Var1 = DecompGEP1.VarIndices[1];
-        if (Var0.Scale == -Var1.Scale &&
+        if (Var0.Scale == -Var1.Scale && Var0.Val.TruncBits == 0 &&
             Var0.Val.hasSameCastsAs(Var1.Val) && VisitedPhiBBs.empty() &&
             isKnownNonEqual(Var0.Val.V, Var1.Val.V, DL, &AC, /* CxtI */ nullptr,
                             DT))
@@ -1835,7 +1858,8 @@ bool BasicAAResult::constantOffsetHeuristic(
 
   const VariableGEPIndex &Var0 = GEP.VarIndices[0], &Var1 = GEP.VarIndices[1];
 
-  if (!Var0.Val.hasSameCastsAs(Var1.Val) || Var0.Scale != -Var1.Scale ||
+  if (Var0.Val.TruncBits != 0 || !Var0.Val.hasSameCastsAs(Var1.Val) ||
+      Var0.Scale != -Var1.Scale ||
       Var0.Val.V->getType() != Var1.Val.V->getType())
     return false;
 
index 1e9c729..09197b6 100644 (file)
@@ -44,15 +44,14 @@ define void @noalias_overflow_in_32_bit_constants(i8* %ptr) {
   ret void
 }
 
-; FIXME: Currently we incorrectly determine NoAlias for %gep.1 and %gep.2. The
-; GEP indices get implicitly truncated to 32 bit, so multiples of 2^32
+; The GEP indices get implicitly truncated to 32 bit, so multiples of 2^32
 ; (=4294967296) will be 0.
 ; See https://alive2.llvm.org/ce/z/HHjQgb.
 define void @mustalias_overflow_in_32_bit_add_mul_gep(i8* %ptr, i64 %i) {
 ; CHECK-LABEL: Function: mustalias_overflow_in_32_bit_add_mul_gep: 3 pointers, 1 call sites
-; CHECK-NEXT:    NoAlias:  i8* %gep.1, i8* %ptr
-; CHECK-NEXT:    NoAlias:  i8* %gep.2, i8* %ptr
-; CHECK-NEXT:    NoAlias:  i8* %gep.1, i8* %gep.2
+; CHECK-NEXT:    MayAlias: i8* %gep.1, i8* %ptr
+; CHECK-NEXT:    MayAlias: i8* %gep.2, i8* %ptr
+; CHECK-NEXT:    MayAlias: i8* %gep.1, i8* %gep.2
 ;
   %s.1 = icmp sgt i64 %i, 0
   call void @llvm.assume(i1 %s.1)
@@ -66,10 +65,9 @@ define void @mustalias_overflow_in_32_bit_add_mul_gep(i8* %ptr, i64 %i) {
   ret void
 }
 
-; FIXME: While %n is non-zero, its low 32 bits may not be.
 define void @mayalias_overflow_in_32_bit_non_zero(i8* %ptr, i64 %n) {
 ; CHECK-LABEL: Function: mayalias_overflow_in_32_bit_non_zero
-; CHECK:    NoAlias: i8* %gep, i8* %ptr
+; CHECK:    MayAlias: i8* %gep, i8* %ptr
 ;
   %c = icmp ne i64 %n, 0
   call void @llvm.assume(i1 %c)
@@ -79,12 +77,11 @@ define void @mayalias_overflow_in_32_bit_non_zero(i8* %ptr, i64 %n) {
   ret void
 }
 
-; FIXME: While %n is positive, its low 32 bits may not be.
 define void @mayalias_overflow_in_32_bit_positive(i8* %ptr, i64 %n) {
 ; CHECK-LABEL: Function: mayalias_overflow_in_32_bit_positive
 ; CHECK:    NoAlias: i8* %gep.1, i8* %ptr
-; CHECK:    NoAlias: i8* %gep.2, i8* %ptr
-; CHECK:    NoAlias: i8* %gep.1, i8* %gep.2
+; CHECK:    MayAlias: i8* %gep.2, i8* %ptr
+; CHECK:    MayAlias: i8* %gep.1, i8* %gep.2
 ;
   %c = icmp sgt i64 %n, 0
   call void @llvm.assume(i1 %c)