[InstCombine] fold icmp of truncated left shift, part 2
authorSanjay Patel <spatel@rotateright.com>
Thu, 8 Sep 2022 16:09:26 +0000 (12:09 -0400)
committerSanjay Patel <spatel@rotateright.com>
Thu, 8 Sep 2022 16:44:02 +0000 (12:44 -0400)
(trunc (1 << Y) to iN) == 2**C --> Y == C
(trunc (1 << Y) to iN) != 2**C --> Y != C
https://alive2.llvm.org/ce/z/xnFPo5

Follow-up to d9e1f9d7591b0d3e4d. This was a suggested
enhancement mentioned in issue #51889.

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/test/Transforms/InstCombine/icmp-trunc.ll

index 35b9e37..0969e9f 100644 (file)
@@ -1552,15 +1552,20 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
   unsigned DstBits = Trunc->getType()->getScalarSizeInBits(),
            SrcBits = SrcTy->getScalarSizeInBits();
 
-  // (trunc (1 << Y) to iN) == 0 --> Y u>= N
-  // (trunc (1 << Y) to iN) != 0 --> Y u<  N
   // TODO: Handle any shifted constant by subtracting trailing zeros.
   // TODO: Handle non-equality predicates.
-  // TODO: Handle compare to power-of-2 (non-zero) constant.
   Value *Y;
-  if (Cmp.isEquality() && C.isZero() && match(X, m_Shl(m_One(), m_Value(Y)))) {
-    auto NewPred = Pred == Cmp.ICMP_EQ ? Cmp.ICMP_UGE : Cmp.ICMP_ULT;
-    return new ICmpInst(NewPred, Y, ConstantInt::get(SrcTy, DstBits));
+  if (Cmp.isEquality() && match(X, m_Shl(m_One(), m_Value(Y)))) {
+    // (trunc (1 << Y) to iN) == 0 --> Y u>= N
+    // (trunc (1 << Y) to iN) != 0 --> Y u<  N
+    if (C.isZero()) {
+      auto NewPred = (Pred == Cmp.ICMP_EQ) ? Cmp.ICMP_UGE : Cmp.ICMP_ULT;
+      return new ICmpInst(NewPred, Y, ConstantInt::get(SrcTy, DstBits));
+    }
+    // (trunc (1 << Y) to iN) == 2**C --> Y == C
+    // (trunc (1 << Y) to iN) != 2**C --> Y != C
+    if (C.isPowerOf2())
+      return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2()));
   }
 
   if (Cmp.isEquality() && Trunc->hasOneUse()) {
index 5d4ccf7..fdf06b7 100644 (file)
@@ -436,14 +436,12 @@ define i1 @shl1_trunc_sgt0(i9 %a) {
   ret i1 %r
 }
 
-; TODO: A == 0
-
 define i1 @shl1_trunc_eq1(i64 %a) {
 ; CHECK-LABEL: @shl1_trunc_eq1(
 ; CHECK-NEXT:    [[SHL:%.*]] = shl i64 1, [[A:%.*]]
 ; CHECK-NEXT:    [[T:%.*]] = trunc i64 [[SHL]] to i8
 ; CHECK-NEXT:    call void @use(i8 [[T]])
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[T]], 1
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i64 [[A]], 0
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %shl = shl i64 1, %a
@@ -453,14 +451,11 @@ define i1 @shl1_trunc_eq1(i64 %a) {
   ret i1 %r
 }
 
-; TODO: A != 5
-
 define i1 @shl1_trunc_ne32(i8 %a) {
 ; CHECK-LABEL: @shl1_trunc_ne32(
 ; CHECK-NEXT:    [[SHL:%.*]] = shl i8 1, [[A:%.*]]
 ; CHECK-NEXT:    call void @use(i8 [[SHL]])
-; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[SHL]], 63
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[TMP1]], 32
+; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[A]], 5
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %shl = shl i8 1, %a