[InstCombine] Fold IntToPtr/PtrToInt to bitcast
authorKrishna Kariya <krishna17060@iiitd.ac.in>
Sun, 18 Jul 2021 21:13:25 +0000 (23:13 +0200)
committerNikita Popov <nikita.ppv@gmail.com>
Sun, 18 Jul 2021 21:13:25 +0000 (23:13 +0200)
The inttoptr/ptrtoint roundtrip optimization is not always correct.
We are working towards removing this optimization and adding support
to specific cases where this optimization works. This patch is the
first one on this line.

Consider the example:

    %i = ptrtoint i8* %X to i64
    %p = inttoptr i64 %i to i16*
    %cmp = icmp eq i8* %load, %p

In this specific case, the inttoptr/ptrtoint optimization is correct
as it only compares the pointer values. In this patch, we fold
inttoptr/ptrtoint to a bitcast (if src and dest types are different).

Differential Revision: https://reviews.llvm.org/D105088

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/lib/Transforms/InstCombine/InstCombineInternal.h
llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
llvm/test/Transforms/InstCombine/ptr-int-ptr-icmp.ll

index 6de4c4d..2b0ef0c 100644 (file)
@@ -4571,6 +4571,16 @@ static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp,
 
 /// Handle icmp (cast x), (cast or constant).
 Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
+  // If any operand of ICmp is a inttoptr roundtrip cast then remove it as
+  // icmp compares only pointer's value.
+  // icmp (inttoptr (ptrtoint p1)), p2 --> icmp p1, p2.
+  Value *SimplifiedOp0 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(0));
+  Value *SimplifiedOp1 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(1));
+  if (SimplifiedOp0 || SimplifiedOp1)
+    return new ICmpInst(ICmp.getPredicate(),
+                        SimplifiedOp0 ? SimplifiedOp0 : ICmp.getOperand(0),
+                        SimplifiedOp1 ? SimplifiedOp1 : ICmp.getOperand(1));
+
   auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0));
   if (!CastOp0)
     return nullptr;
index 7655280..6d0690d 100644 (file)
@@ -340,6 +340,7 @@ private:
   /// \see CastInst::isEliminableCastPair
   Instruction::CastOps isEliminableCastPair(const CastInst *CI1,
                                             const CastInst *CI2);
+  Value *simplifyIntToPtrRoundTripCast(Value *Val);
 
   Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &And);
   Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Or);
index 9646a8c..07b354f 100644 (file)
@@ -346,6 +346,25 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1,
   return true;
 }
 
+// Simplifies IntToPtr/PtrToInt RoundTrip Cast To BitCast.
+// inttoptr ( ptrtoint (x) ) --> x
+Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) {
+  auto *IntToPtr = dyn_cast<IntToPtrInst>(Val);
+  if (IntToPtr && DL.getPointerTypeSizeInBits(IntToPtr->getDestTy()) ==
+                      DL.getTypeSizeInBits(IntToPtr->getSrcTy())) {
+    auto *PtrToInt = dyn_cast<PtrToIntInst>(IntToPtr->getOperand(0));
+    Type *CastTy = IntToPtr->getDestTy();
+    if (PtrToInt &&
+        CastTy->getPointerAddressSpace() ==
+            PtrToInt->getSrcTy()->getPointerAddressSpace() &&
+        DL.getPointerTypeSizeInBits(PtrToInt->getSrcTy()) ==
+            DL.getTypeSizeInBits(PtrToInt->getDestTy())) {
+      return Builder.CreateBitCast(PtrToInt->getOperand(0), CastTy);
+    }
+  }
+  return nullptr;
+}
+
 /// This performs a few simplifications for operators that are associative or
 /// commutative:
 ///
index 9784f18..f99685c 100644 (file)
@@ -8,9 +8,7 @@ target triple = "x86_64-unknown-linux-gnu"
 
 define i1 @func(i8* %X, i8* %Y) {
 ; CHECK-LABEL: @func(
-; CHECK-NEXT:    [[I:%.*]] = ptrtoint i8* [[X:%.*]] to i64
-; CHECK-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to i8*
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[P]], [[Y:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %i = ptrtoint i8* %X to i64
@@ -21,9 +19,8 @@ define i1 @func(i8* %X, i8* %Y) {
 
 define i1 @func_pointer_different_types(i16* %X, i8* %Y) {
 ; CHECK-LABEL: @func_pointer_different_types(
-; CHECK-NEXT:    [[I:%.*]] = ptrtoint i16* [[X:%.*]] to i64
-; CHECK-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to i8*
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[P]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i16* [[X:%.*]] to i8*
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[TMP1]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %i = ptrtoint i16* %X to i64
@@ -37,9 +34,8 @@ declare i8* @gen8ptr()
 define i1 @func_commutative(i16* %X) {
 ; CHECK-LABEL: @func_commutative(
 ; CHECK-NEXT:    [[Y:%.*]] = call i8* @gen8ptr()
-; CHECK-NEXT:    [[I:%.*]] = ptrtoint i16* [[X:%.*]] to i64
-; CHECK-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to i8*
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[Y]], [[P]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i16* [[X:%.*]] to i8*
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[Y]], [[TMP1]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %Y = call i8* @gen8ptr() ; thwart complexity-based canonicalization