[InstCombine] Don't rewrite phi-of-bitcast when the phi has other users
authorConnor Abbott <cwabbott0@gmail.com>
Tue, 31 Dec 2019 11:11:35 +0000 (12:11 +0100)
committerNikita Popov <nikita.ppv@gmail.com>
Tue, 31 Dec 2019 11:15:02 +0000 (12:15 +0100)
Judging by the existing comments, this was the intention, but the
transform never actually checked if the existing phi's would be removed.
See https://bugs.llvm.org/show_bug.cgi?id=44242 for an example where
this causes much worse code generation on AMDGPU.

Differential Revision: https://reviews.llvm.org/D71209

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
llvm/test/Transforms/InstCombine/pr44242.ll

index e16919e..3ba56bb 100644 (file)
@@ -2277,6 +2277,31 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
     }
   }
 
+  // Check that each user of each old PHI node is something that we can
+  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
+  for (auto *OldPN : OldPhiNodes) {
+    for (User *V : OldPN->users()) {
+      if (auto *SI = dyn_cast<StoreInst>(V)) {
+        if (!SI->isSimple() || SI->getOperand(0) != OldPN)
+          return nullptr;
+      } else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
+        // Verify it's a B->A cast.
+        Type *TyB = BCI->getOperand(0)->getType();
+        Type *TyA = BCI->getType();
+        if (TyA != DestTy || TyB != SrcTy)
+          return nullptr;
+      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+        // As long as the user is another old PHI node, then even if we don't
+        // rewrite it, the PHI web we're considering won't have any users
+        // outside itself, so it'll be dead.
+        if (OldPhiNodes.count(PHI) == 0)
+          return nullptr;
+      } else {
+        return nullptr;
+      }
+    }
+  }
+
   // For each old PHI node, create a corresponding new PHI node with a type A.
   SmallDenseMap<PHINode *, PHINode *> NewPNodes;
   for (auto *OldPN : OldPhiNodes) {
@@ -2321,24 +2346,28 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
     PHINode *NewPN = NewPNodes[OldPN];
     for (User *V : OldPN->users()) {
       if (auto *SI = dyn_cast<StoreInst>(V)) {
-        if (SI->isSimple() && SI->getOperand(0) == OldPN) {
-          Builder.SetInsertPoint(SI);
-          auto *NewBC =
-            cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy));
-          SI->setOperand(0, NewBC);
-          Worklist.Add(SI);
-          assert(hasStoreUsersOnly(*NewBC));
-        }
+        assert(SI->isSimple() && SI->getOperand(0) == OldPN);
+        Builder.SetInsertPoint(SI);
+        auto *NewBC =
+          cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy));
+        SI->setOperand(0, NewBC);
+        Worklist.Add(SI);
+        assert(hasStoreUsersOnly(*NewBC));
       }
       else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
-        // Verify it's a B->A cast.
         Type *TyB = BCI->getOperand(0)->getType();
         Type *TyA = BCI->getType();
-        if (TyA == DestTy && TyB == SrcTy) {
-          Instruction *I = replaceInstUsesWith(*BCI, NewPN);
-          if (BCI == &CI)
-            RetVal = I;
-        }
+        assert(TyA == DestTy && TyB == SrcTy);
+        (void) TyA;
+        (void) TyB;
+        Instruction *I = replaceInstUsesWith(*BCI, NewPN);
+        if (BCI == &CI)
+          RetVal = I;
+      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+        assert(OldPhiNodes.count(PHI) > 0);
+        (void) PHI;
+      } else {
+        llvm_unreachable("all uses should be handled");
       }
     }
   }
index e478193..5e783af 100644 (file)
@@ -10,17 +10,17 @@ define float @sitofp(float %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP_HEADER:%.*]]
 ; CHECK:       loop_header:
-; CHECK-NEXT:    [[TMP0:%.*]] = phi float [ 0.000000e+00, [[ENTRY:%.*]] ], [ [[VAL_INCR:%.*]], [[LOOP:%.*]] ]
-; CHECK-NEXT:    [[VAL:%.*]] = phi float [ 0.000000e+00, [[ENTRY]] ], [ [[PHITMP:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt float [[TMP0]], [[X:%.*]]
+; CHECK-NEXT:    [[VAL:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[VAL_INCR_CASTED:%.*]], [[LOOP:%.*]] ]
+; CHECK-NEXT:    [[VAL_CASTED:%.*]] = bitcast i32 [[VAL]] to float
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt float [[VAL_CASTED]], [[X:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP]], label [[END:%.*]], label [[LOOP]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[VAL_INCR]] = fadd float [[TMP0]], 1.000000e+00
-; CHECK-NEXT:    [[VAL_INCR_CASTED:%.*]] = bitcast float [[VAL_INCR]] to i32
-; CHECK-NEXT:    [[PHITMP]] = sitofp i32 [[VAL_INCR_CASTED]] to float
+; CHECK-NEXT:    [[VAL_INCR:%.*]] = fadd float [[VAL_CASTED]], 1.000000e+00
+; CHECK-NEXT:    [[VAL_INCR_CASTED]] = bitcast float [[VAL_INCR]] to i32
 ; CHECK-NEXT:    br label [[LOOP_HEADER]]
 ; CHECK:       end:
-; CHECK-NEXT:    ret float [[VAL]]
+; CHECK-NEXT:    [[RESULT:%.*]] = sitofp i32 [[VAL]] to float
+; CHECK-NEXT:    ret float [[RESULT]]
 ;
 entry:
   br label %loop_header
@@ -44,16 +44,17 @@ define <2 x i16> @bitcast(float %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP_HEADER:%.*]]
 ; CHECK:       loop_header:
-; CHECK-NEXT:    [[TMP0:%.*]] = phi float [ 0.000000e+00, [[ENTRY:%.*]] ], [ [[VAL_INCR:%.*]], [[LOOP:%.*]] ]
-; CHECK-NEXT:    [[VAL:%.*]] = phi <2 x i16> [ zeroinitializer, [[ENTRY]] ], [ [[PHITMP:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt float [[TMP0]], [[X:%.*]]
+; CHECK-NEXT:    [[VAL:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[VAL_INCR_CASTED:%.*]], [[LOOP:%.*]] ]
+; CHECK-NEXT:    [[VAL_CASTED:%.*]] = bitcast i32 [[VAL]] to float
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt float [[VAL_CASTED]], [[X:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP]], label [[END:%.*]], label [[LOOP]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[VAL_INCR]] = fadd float [[TMP0]], 1.000000e+00
-; CHECK-NEXT:    [[PHITMP]] = bitcast float [[VAL_INCR]] to <2 x i16>
+; CHECK-NEXT:    [[VAL_INCR:%.*]] = fadd float [[VAL_CASTED]], 1.000000e+00
+; CHECK-NEXT:    [[VAL_INCR_CASTED]] = bitcast float [[VAL_INCR]] to i32
 ; CHECK-NEXT:    br label [[LOOP_HEADER]]
 ; CHECK:       end:
-; CHECK-NEXT:    ret <2 x i16> [[VAL]]
+; CHECK-NEXT:    [[RESULT:%.*]] = bitcast i32 [[VAL]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[RESULT]]
 ;
 entry:
   br label %loop_header
@@ -79,12 +80,12 @@ define void @store_volatile(float %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP_HEADER:%.*]]
 ; CHECK:       loop_header:
-; CHECK-NEXT:    [[TMP0:%.*]] = phi float [ 0.000000e+00, [[ENTRY:%.*]] ], [ [[VAL_INCR:%.*]], [[LOOP:%.*]] ]
-; CHECK-NEXT:    [[VAL:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[VAL_INCR_CASTED:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt float [[TMP0]], [[X:%.*]]
+; CHECK-NEXT:    [[VAL:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[VAL_INCR_CASTED:%.*]], [[LOOP:%.*]] ]
+; CHECK-NEXT:    [[VAL_CASTED:%.*]] = bitcast i32 [[VAL]] to float
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt float [[VAL_CASTED]], [[X:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP]], label [[END:%.*]], label [[LOOP]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[VAL_INCR]] = fadd float [[TMP0]], 1.000000e+00
+; CHECK-NEXT:    [[VAL_INCR:%.*]] = fadd float [[VAL_CASTED]], 1.000000e+00
 ; CHECK-NEXT:    [[VAL_INCR_CASTED]] = bitcast float [[VAL_INCR]] to i32
 ; CHECK-NEXT:    br label [[LOOP_HEADER]]
 ; CHECK:       end:
@@ -113,13 +114,11 @@ define void @store_address(i32 %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP_HEADER:%.*]]
 ; CHECK:       loop_header:
-; CHECK-NEXT:    [[TMP0:%.*]] = phi float* [ bitcast (i32* @global to float*), [[ENTRY:%.*]] ], [ [[VAL_INCR:%.*]], [[LOOP:%.*]] ]
-; CHECK-NEXT:    [[VAL:%.*]] = phi i32* [ @global, [[ENTRY]] ], [ [[VAL_INCR_CASTED:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[VAL:%.*]] = phi i32* [ @global, [[ENTRY:%.*]] ], [ [[VAL_INCR1:%.*]], [[LOOP:%.*]] ]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[X:%.*]], 0
 ; CHECK-NEXT:    br i1 [[CMP]], label [[END:%.*]], label [[LOOP]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[VAL_INCR]] = getelementptr float, float* [[TMP0]], i64 1
-; CHECK-NEXT:    [[VAL_INCR_CASTED]] = bitcast float* [[VAL_INCR]] to i32*
+; CHECK-NEXT:    [[VAL_INCR1]] = getelementptr i32, i32* [[VAL]], i64 1
 ; CHECK-NEXT:    br label [[LOOP_HEADER]]
 ; CHECK:       end:
 ; CHECK-NEXT:    store i32 0, i32* [[VAL]], align 4
@@ -150,19 +149,18 @@ define i32 @multiple_phis(float %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP_HEADER:%.*]]
 ; CHECK:       loop_header:
-; CHECK-NEXT:    [[TMP0:%.*]] = phi float [ 0.000000e+00, [[ENTRY:%.*]] ], [ [[TMP1:%.*]], [[LOOP_END:%.*]] ]
-; CHECK-NEXT:    [[VAL:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[VAL2:%.*]], [[LOOP_END]] ]
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt float [[TMP0]], [[X:%.*]]
+; CHECK-NEXT:    [[VAL:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[VAL2:%.*]], [[LOOP_END:%.*]] ]
+; CHECK-NEXT:    [[VAL_CASTED:%.*]] = bitcast i32 [[VAL]] to float
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt float [[VAL_CASTED]], [[X:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP]], label [[END:%.*]], label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[CMP2:%.*]] = fcmp ogt float [[TMP0]], 2.000000e+00
+; CHECK-NEXT:    [[CMP2:%.*]] = fcmp ogt float [[VAL_CASTED]], 2.000000e+00
 ; CHECK-NEXT:    br i1 [[CMP2]], label [[IF:%.*]], label [[LOOP_END]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[VAL_INCR:%.*]] = fadd float [[TMP0]], 1.000000e+00
+; CHECK-NEXT:    [[VAL_INCR:%.*]] = fadd float [[VAL_CASTED]], 1.000000e+00
 ; CHECK-NEXT:    [[VAL_INCR_CASTED:%.*]] = bitcast float [[VAL_INCR]] to i32
 ; CHECK-NEXT:    br label [[LOOP_END]]
 ; CHECK:       loop_end:
-; CHECK-NEXT:    [[TMP1]] = phi float [ [[TMP0]], [[LOOP]] ], [ [[VAL_INCR]], [[IF]] ]
 ; CHECK-NEXT:    [[VAL2]] = phi i32 [ [[VAL]], [[LOOP]] ], [ [[VAL_INCR_CASTED]], [[IF]] ]
 ; CHECK-NEXT:    store volatile i32 [[VAL2]], i32* @global, align 4
 ; CHECK-NEXT:    br label [[LOOP_HEADER]]