[SCEVExpander] Add option to preserve LCSSA directly.
authorFlorian Hahn <flo@fhahn.com>
Wed, 29 Jul 2020 13:54:03 +0000 (14:54 +0100)
committerFlorian Hahn <flo@fhahn.com>
Wed, 29 Jul 2020 14:07:37 +0000 (15:07 +0100)
This patch teaches SCEVExpander to directly preserve LCSSA.

As it is currently, SCEV does not look through PHI nodes in loops,
as it might break LCSSA form. Once SCEVExpander can preserve
LCSSA form, it should be safe for SCEV to look through PHIs.

To preserve LCSSA form, this patch uses formLCSSAForInstructions
on operands of newly created instructions, if the definition is inside
a different loop than the new instruction.

The final value we return from expandCodeFor may also need LCSSA
phis, depending on the insert point. As no user for it exists there yet,
create a temporary instruction at the insert point, which can be passed
to formLCSSAForInstructions. This temporary instruction is removed
after LCSSA construction.

Reviewed By: mkazantsev

Differential Revision: https://reviews.llvm.org/D71538

llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp

index 6e53d15..cb212b2 100644 (file)
@@ -52,6 +52,9 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   // New instructions receive a name to identify them with the current pass.
   const char *IVName;
 
+  /// Indicates whether LCSSA phis should be created for inserted values.
+  bool PreserveLCSSA;
+
   // InsertedExpressions caches Values for reuse, so must track RAUW.
   DenseMap<std::pair<const SCEV *, Instruction *>, TrackingVH<Value>>
       InsertedExpressions;
@@ -146,9 +149,10 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
 public:
   /// Construct a SCEVExpander in "canonical" mode.
   explicit SCEVExpander(ScalarEvolution &se, const DataLayout &DL,
-                        const char *name)
-      : SE(se), DL(DL), IVName(name), IVIncInsertLoop(nullptr),
-        IVIncInsertPos(nullptr), CanonicalMode(true), LSRMode(false),
+                        const char *name, bool PreserveLCSSA = true)
+      : SE(se), DL(DL), IVName(name), PreserveLCSSA(PreserveLCSSA),
+        IVIncInsertLoop(nullptr), IVIncInsertPos(nullptr), CanonicalMode(true),
+        LSRMode(false),
         Builder(se.getContext(), TargetFolder(DL),
                 IRBuilderCallbackInserter(
                     [this](Instruction *I) { rememberInstruction(I); })) {
@@ -223,14 +227,18 @@ public:
                                const TargetTransformInfo *TTI = nullptr);
 
   /// Insert code to directly compute the specified SCEV expression into the
-  /// program.  The inserted code is inserted into the specified block.
-  Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I);
+  /// program.  The code is inserted into the specified block.
+  Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I) {
+    return expandCodeForImpl(SH, Ty, I, true);
+  }
 
   /// Insert code to directly compute the specified SCEV expression into the
-  /// program.  The inserted code is inserted into the SCEVExpander's current
+  /// program.  The code is inserted into the SCEVExpander's current
   /// insertion point. If a type is specified, the result will be expanded to
   /// have that type, with a cast if necessary.
-  Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr);
+  Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr) {
+    return expandCodeForImpl(SH, Ty, true);
+  }
 
   /// Generates a code sequence that evaluates this predicate.  The inserted
   /// instructions will be at position \p Loc.  The result will be of type i1
@@ -338,6 +346,20 @@ public:
 private:
   LLVMContext &getContext() const { return SE.getContext(); }
 
+  /// Insert code to directly compute the specified SCEV expression into the
+  /// program. The code is inserted into the SCEVExpander's current
+  /// insertion point. If a type is specified, the result will be expanded to
+  /// have that type, with a cast if necessary. If \p Root is true, this
+  /// indicates that \p SH is the top-level expression to expand passed from
+  /// an external client call.
+  Value *expandCodeForImpl(const SCEV *SH, Type *Ty, bool Root);
+
+  /// Insert code to directly compute the specified SCEV expression into the
+  /// program. The code is inserted into the specified block. If \p
+  /// Root is true, this indicates that \p SH is the top-level expression to
+  /// expand passed from an external client call.
+  Value *expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *I, bool Root);
+
   /// Recursive helper function for isHighCostExpansion.
   bool isHighCostExpansionHelper(const SCEV *S, Loop *L, const Instruction &At,
                                  int &BudgetRemaining,
@@ -419,6 +441,11 @@ private:
                       Instruction *Pos, PHINode *LoopPhi);
 
   void fixupInsertPoints(Instruction *I);
+
+  /// If required, create LCSSA PHIs for \p Users' operand \p OpIdx. If new
+  /// LCSSA PHIs have been created, return the LCSSA PHI available at \p User.
+  /// If no PHIs have been created, return the unchanged operand \p OpIdx.
+  Value *fixupLCSSAFormFor(Instruction *User, unsigned OpIdx);
 };
 } // namespace llvm
 
index cf02ef1..c3e46c1 100644 (file)
@@ -5514,8 +5514,8 @@ void LSRInstance::ImplementSolution(
   // we can remove them after we are done working.
   SmallVector<WeakTrackingVH, 16> DeadInsts;
 
-  SCEVExpander Rewriter(SE, L->getHeader()->getModule()->getDataLayout(),
-                        "lsr");
+  SCEVExpander Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), "lsr",
+                        false);
 #ifndef NDEBUG
   Rewriter.setDebugType(DEBUG_TYPE);
 #endif
@@ -5780,7 +5780,7 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
   if (EnablePhiElim && L->isLoopSimplifyForm()) {
     SmallVector<WeakTrackingVH, 16> DeadInsts;
     const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
-    SCEVExpander Rewriter(SE, DL, "lsr");
+    SCEVExpander Rewriter(SE, DL, "lsr", false);
 #ifndef NDEBUG
     Rewriter.setDebugType(DEBUG_TYPE);
 #endif
index c99d612..1a10e58 100644 (file)
@@ -27,6 +27,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
 
 using namespace llvm;
 
@@ -461,9 +462,10 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin,
     // we didn't find any operands that could be factored, tentatively
     // assume that element zero was selected (since the zero offset
     // would obviously be folded away).
-    Value *Scaled = ScaledOps.empty() ?
-                    Constant::getNullValue(Ty) :
-                    expandCodeFor(SE.getAddExpr(ScaledOps), Ty);
+    Value *Scaled =
+        ScaledOps.empty()
+            ? Constant::getNullValue(Ty)
+            : expandCodeForImpl(SE.getAddExpr(ScaledOps), Ty, false);
     GepIndices.push_back(Scaled);
 
     // Collect struct field index operands.
@@ -522,7 +524,7 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin,
            SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint()));
 
     // Expand the operands for a plain byte offset.
-    Value *Idx = expandCodeFor(SE.getAddExpr(Ops), Ty);
+    Value *Idx = expandCodeForImpl(SE.getAddExpr(Ops), Ty, false);
 
     // Fold a GEP with constant operands.
     if (Constant *CLHS = dyn_cast<Constant>(V))
@@ -743,14 +745,14 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
       Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, expand(Op));
     } else if (Op->isNonConstantNegative()) {
       // Instead of doing a negate and add, just do a subtract.
-      Value *W = expandCodeFor(SE.getNegativeSCEV(Op), Ty);
+      Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty, false);
       Sum = InsertNoopCastOfTo(Sum, Ty);
       Sum = InsertBinop(Instruction::Sub, Sum, W, SCEV::FlagAnyWrap,
                         /*IsSafeToHoist*/ true);
       ++I;
     } else {
       // A simple add.
-      Value *W = expandCodeFor(Op, Ty);
+      Value *W = expandCodeForImpl(Op, Ty, false);
       Sum = InsertNoopCastOfTo(Sum, Ty);
       // Canonicalize a constant to the RHS.
       if (isa<Constant>(Sum)) std::swap(Sum, W);
@@ -802,7 +804,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
 
     // Calculate powers with exponents 1, 2, 4, 8 etc. and include those of them
     // that are needed into the result.
-    Value *P = expandCodeFor(I->second, Ty);
+    Value *P = expandCodeForImpl(I->second, Ty, false);
     Value *Result = nullptr;
     if (Exponent & 1)
       Result = P;
@@ -861,7 +863,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
 Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
   Type *Ty = SE.getEffectiveSCEVType(S->getType());
 
-  Value *LHS = expandCodeFor(S->getLHS(), Ty);
+  Value *LHS = expandCodeForImpl(S->getLHS(), Ty, false);
   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) {
     const APInt &RHS = SC->getAPInt();
     if (RHS.isPowerOf2())
@@ -870,7 +872,7 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
                          SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true);
   }
 
-  Value *RHS = expandCodeFor(S->getRHS(), Ty);
+  Value *RHS = expandCodeForImpl(S->getRHS(), Ty, false);
   return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
                      /*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
 }
@@ -1265,8 +1267,9 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
   // Expand code for the start value into the loop preheader.
   assert(L->getLoopPreheader() &&
          "Can't expand add recurrences without a loop preheader!");
-  Value *StartV = expandCodeFor(Normalized->getStart(), ExpandTy,
-                                L->getLoopPreheader()->getTerminator());
+  Value *StartV =
+      expandCodeForImpl(Normalized->getStart(), ExpandTy,
+                        L->getLoopPreheader()->getTerminator(), false);
 
   // StartV must have been be inserted into L's preheader to dominate the new
   // phi.
@@ -1284,8 +1287,8 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
   if (useSubtract)
     Step = SE.getNegativeSCEV(Step);
   // Expand the step somewhere that dominates the loop header.
-  Value *StepV = expandCodeFor(Step, IntTy,
-                               &*L->getHeader()->getFirstInsertionPt());
+  Value *StepV = expandCodeForImpl(
+      Step, IntTy, &*L->getHeader()->getFirstInsertionPt(), false);
 
   // The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if
   // we actually do emit an addition.  It does not apply if we emit a
@@ -1430,8 +1433,8 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
       {
         // Expand the step somewhere that dominates the loop header.
         SCEVInsertPointGuard Guard(Builder, this);
-        StepV = expandCodeFor(Step, IntTy,
-                              &*L->getHeader()->getFirstInsertionPt());
+        StepV = expandCodeForImpl(
+            Step, IntTy, &*L->getHeader()->getFirstInsertionPt(), false);
       }
       Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract);
     }
@@ -1450,8 +1453,8 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
 
     // Invert the result.
     if (InvertStep)
-      Result = Builder.CreateSub(expandCodeFor(Normalized->getStart(), TruncTy),
-                                 Result);
+      Result = Builder.CreateSub(
+          expandCodeForImpl(Normalized->getStart(), TruncTy, false), Result);
   }
 
   // Re-apply any non-loop-dominating scale.
@@ -1459,22 +1462,22 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
     assert(S->isAffine() && "Can't linearly scale non-affine recurrences.");
     Result = InsertNoopCastOfTo(Result, IntTy);
     Result = Builder.CreateMul(Result,
-                               expandCodeFor(PostLoopScale, IntTy));
+                               expandCodeForImpl(PostLoopScale, IntTy, false));
   }
 
   // Re-apply any non-loop-dominating offset.
   if (PostLoopOffset) {
     if (PointerType *PTy = dyn_cast<PointerType>(ExpandTy)) {
       if (Result->getType()->isIntegerTy()) {
-        Value *Base = expandCodeFor(PostLoopOffset, ExpandTy);
+        Value *Base = expandCodeForImpl(PostLoopOffset, ExpandTy, false);
         Result = expandAddToGEP(SE.getUnknown(Result), PTy, IntTy, Base);
       } else {
         Result = expandAddToGEP(PostLoopOffset, PTy, IntTy, Result);
       }
     } else {
       Result = InsertNoopCastOfTo(Result, IntTy);
-      Result = Builder.CreateAdd(Result,
-                                 expandCodeFor(PostLoopOffset, IntTy));
+      Result = Builder.CreateAdd(
+          Result, expandCodeForImpl(PostLoopOffset, IntTy, false));
     }
   }
 
@@ -1516,8 +1519,8 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
                                        S->getNoWrapFlags(SCEV::FlagNW)));
     BasicBlock::iterator NewInsertPt =
         findInsertPointAfter(cast<Instruction>(V), Builder.GetInsertBlock());
-    V = expandCodeFor(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr,
-                      &*NewInsertPt);
+    V = expandCodeForImpl(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr,
+                          &*NewInsertPt, false);
     return V;
   }
 
@@ -1632,22 +1635,25 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
 
 Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) {
   Type *Ty = SE.getEffectiveSCEVType(S->getType());
-  Value *V = expandCodeFor(S->getOperand(),
-                           SE.getEffectiveSCEVType(S->getOperand()->getType()));
+  Value *V = expandCodeForImpl(
+      S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()),
+      false);
   return Builder.CreateTrunc(V, Ty);
 }
 
 Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) {
   Type *Ty = SE.getEffectiveSCEVType(S->getType());
-  Value *V = expandCodeFor(S->getOperand(),
-                           SE.getEffectiveSCEVType(S->getOperand()->getType()));
+  Value *V = expandCodeForImpl(
+      S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()),
+      false);
   return Builder.CreateZExt(V, Ty);
 }
 
 Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
   Type *Ty = SE.getEffectiveSCEVType(S->getType());
-  Value *V = expandCodeFor(S->getOperand(),
-                           SE.getEffectiveSCEVType(S->getOperand()->getType()));
+  Value *V = expandCodeForImpl(
+      S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()),
+      false);
   return Builder.CreateSExt(V, Ty);
 }
 
@@ -1662,7 +1668,7 @@ Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) {
       Ty = SE.getEffectiveSCEVType(Ty);
       LHS = InsertNoopCastOfTo(LHS, Ty);
     }
-    Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+    Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false);
     Value *ICmp = Builder.CreateICmpSGT(LHS, RHS);
     Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smax");
     LHS = Sel;
@@ -1685,7 +1691,7 @@ Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) {
       Ty = SE.getEffectiveSCEVType(Ty);
       LHS = InsertNoopCastOfTo(LHS, Ty);
     }
-    Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+    Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false);
     Value *ICmp = Builder.CreateICmpUGT(LHS, RHS);
     Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umax");
     LHS = Sel;
@@ -1708,7 +1714,7 @@ Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) {
       Ty = SE.getEffectiveSCEVType(Ty);
       LHS = InsertNoopCastOfTo(LHS, Ty);
     }
-    Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+    Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false);
     Value *ICmp = Builder.CreateICmpSLT(LHS, RHS);
     Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smin");
     LHS = Sel;
@@ -1731,7 +1737,7 @@ Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) {
       Ty = SE.getEffectiveSCEVType(Ty);
       LHS = InsertNoopCastOfTo(LHS, Ty);
     }
-    Value *RHS = expandCodeFor(S->getOperand(i), Ty);
+    Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false);
     Value *ICmp = Builder.CreateICmpULT(LHS, RHS);
     Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umin");
     LHS = Sel;
@@ -1743,15 +1749,43 @@ Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) {
   return LHS;
 }
 
-Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
-                                   Instruction *IP) {
+Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty,
+                                       Instruction *IP, bool Root) {
   setInsertPoint(IP);
-  return expandCodeFor(SH, Ty);
+  Value *V = expandCodeForImpl(SH, Ty, Root);
+  return V;
 }
 
-Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) {
+Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, bool Root) {
   // Expand the code for this SCEV.
   Value *V = expand(SH);
+
+  if (PreserveLCSSA) {
+    if (auto *Inst = dyn_cast<Instruction>(V)) {
+      // Create a temporary instruction to at the current insertion point, so we
+      // can hand it off to the helper to create LCSSA PHIs if required for the
+      // new use.
+      // FIXME: Ideally formLCSSAForInstructions (used in fixupLCSSAFormFor)
+      // would accept a insertion point and return an LCSSA phi for that
+      // insertion point, so there is no need to insert & remove the temporary
+      // instruction.
+      Instruction *Tmp;
+      if (Inst->getType()->isIntegerTy())
+        Tmp = cast<Instruction>(Builder.CreateAdd(Inst, Inst));
+      else {
+        assert(Inst->getType()->isPointerTy());
+        Tmp = cast<Instruction>(Builder.CreateGEP(Inst, Builder.getInt32(1)));
+      }
+      V = fixupLCSSAFormFor(Tmp, 0);
+
+      // Clean up temporary instruction.
+      InsertedValues.erase(Tmp);
+      InsertedPostIncValues.erase(Tmp);
+      Tmp->eraseFromParent();
+    }
+  }
+
+  InsertedExpressions[std::make_pair(SH, &*Builder.GetInsertPoint())] = V;
   if (Ty) {
     assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) &&
            "non-trivial casts should be done with the SCEVs directly!");
@@ -1891,10 +1925,28 @@ Value *SCEVExpander::expand(const SCEV *S) {
 }
 
 void SCEVExpander::rememberInstruction(Value *I) {
-  if (!PostIncLoops.empty())
-    InsertedPostIncValues.insert(I);
-  else
-    InsertedValues.insert(I);
+  auto DoInsert = [this](Value *V) {
+    if (!PostIncLoops.empty())
+      InsertedPostIncValues.insert(V);
+    else
+      InsertedValues.insert(V);
+  };
+  DoInsert(I);
+
+  if (!PreserveLCSSA)
+    return;
+
+  if (auto *Inst = dyn_cast<Instruction>(I)) {
+    // A new instruction has been added, which might introduce new uses outside
+    // a defining loop. Fix LCSSA from for each operand of the new instruction,
+    // if required.
+    for (unsigned OpIdx = 0, OpEnd = Inst->getNumOperands(); OpIdx != OpEnd;
+         OpIdx++) {
+      auto *V = fixupLCSSAFormFor(Inst, OpIdx);
+      if (V != I)
+        DoInsert(V);
+    }
+  }
 }
 
 /// getOrInsertCanonicalInductionVariable - This method returns the
@@ -1913,9 +1965,8 @@ SCEVExpander::getOrInsertCanonicalInductionVariable(const Loop *L,
 
   // Emit code for it.
   SCEVInsertPointGuard Guard(Builder, this);
-  PHINode *V =
-      cast<PHINode>(expandCodeFor(H, nullptr,
-                                  &*L->getHeader()->getFirstInsertionPt()));
+  PHINode *V = cast<PHINode>(expandCodeForImpl(
+      H, nullptr, &*L->getHeader()->getFirstInsertionPt(), false));
 
   return V;
 }
@@ -2315,8 +2366,10 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
 
 Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
                                           Instruction *IP) {
-  Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP);
-  Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP);
+  Value *Expr0 =
+      expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP, false);
+  Value *Expr1 =
+      expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP, false);
 
   Builder.SetInsertPoint(IP);
   auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check");
@@ -2348,15 +2401,16 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
 
   IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits);
   Builder.SetInsertPoint(Loc);
-  Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc);
+  Value *TripCountVal = expandCodeForImpl(ExitCount, CountTy, Loc, false);
 
   IntegerType *Ty =
       IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy));
   Type *ARExpandTy = DL.isNonIntegralPointerType(ARTy) ? ARTy : Ty;
 
-  Value *StepValue = expandCodeFor(Step, Ty, Loc);
-  Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc);
-  Value *StartValue = expandCodeFor(Start, ARExpandTy, Loc);
+  Value *StepValue = expandCodeForImpl(Step, Ty, Loc, false);
+  Value *NegStepValue =
+      expandCodeForImpl(SE.getNegativeSCEV(Step), Ty, Loc, false);
+  Value *StartValue = expandCodeForImpl(Start, ARExpandTy, Loc, false);
 
   ConstantInt *Zero =
       ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits));
@@ -2459,6 +2513,25 @@ Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
   return Check;
 }
 
+Value *SCEVExpander::fixupLCSSAFormFor(Instruction *User, unsigned OpIdx) {
+  assert(PreserveLCSSA);
+  SmallVector<Instruction *, 1> ToUpdate;
+
+  auto *OpV = User->getOperand(OpIdx);
+  auto *OpI = dyn_cast<Instruction>(OpV);
+  if (!OpI)
+    return OpV;
+
+  Loop *DefLoop = SE.LI.getLoopFor(OpI->getParent());
+  Loop *UseLoop = SE.LI.getLoopFor(User->getParent());
+  if (!DefLoop || UseLoop == DefLoop || DefLoop->contains(UseLoop))
+    return OpV;
+
+  ToUpdate.push_back(OpI);
+  formLCSSAForInstructions(ToUpdate, SE.DT, SE.LI, &SE);
+  return User->getOperand(OpIdx);
+}
+
 namespace {
 // Search for a SCEV subexpression that is not safe to expand.  Any expression
 // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely
index c146d15..6e83370 100644 (file)
@@ -265,7 +265,7 @@ TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderIsSafeToExpandAt) {
   Phi->addIncoming(Add, L);
 
   Builder.SetInsertPoint(Post);
-  Builder.CreateRetVoid();
+  Instruction *Ret = Builder.CreateRetVoid();
 
   ScalarEvolution SE = buildSE(*F);
   const SCEV *S = SE.getSCEV(Phi);
@@ -276,6 +276,11 @@ TEST_F(ScalarEvolutionExpanderTest, SCEVExpanderIsSafeToExpandAt) {
   EXPECT_FALSE(isSafeToExpandAt(AR, LPh->getTerminator(), SE));
   EXPECT_TRUE(isSafeToExpandAt(AR, L->getTerminator(), SE));
   EXPECT_TRUE(isSafeToExpandAt(AR, Post->getTerminator(), SE));
+
+  EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
+  SCEVExpander Exp(SE, M.getDataLayout(), "expander");
+  Exp.expandCodeFor(SE.getSCEV(Add), nullptr, Ret);
+  EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT));
 }
 
 // Check that SCEV expander does not use the nuw instruction