[MemCpyOpt] Use BatchAA when processing one instruction (NFCI)
authorNikita Popov <npopov@redhat.com>
Tue, 6 Dec 2022 09:14:00 +0000 (10:14 +0100)
committerNikita Popov <npopov@redhat.com>
Tue, 6 Dec 2022 09:16:39 +0000 (10:16 +0100)
While we can't use a single BatchAA instance for the entire
MemCpyOpt run without further justification, we can use BatchAA
while performing the queries related to a single instruction
(these will first perform some AA-based checks, and then modify
the IR only afterwards).

llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp

index 8103b0a..587b782 100644 (file)
@@ -20,6 +20,7 @@
 namespace llvm {
 
 class AAResults;
+class BatchAAResults;
 class AssumptionCache;
 class CallBase;
 class CallInst;
@@ -61,10 +62,14 @@ private:
   bool processMemMove(MemMoveInst *M);
   bool performCallSlotOptzn(Instruction *cpyLoad, Instruction *cpyStore,
                             Value *cpyDst, Value *cpySrc, TypeSize cpyLen,
-                            Align cpyAlign, std::function<CallInst *()> GetC);
-  bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep);
-  bool processMemSetMemCpyDependence(MemCpyInst *MemCpy, MemSetInst *MemSet);
-  bool performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, MemSetInst *MemSet);
+                            Align cpyAlign, BatchAAResults &BAA,
+                            std::function<CallInst *()> GetC);
+  bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep,
+                                     BatchAAResults &BAA);
+  bool processMemSetMemCpyDependence(MemCpyInst *MemCpy, MemSetInst *MemSet,
+                                     BatchAAResults &BAA);
+  bool performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, MemSetInst *MemSet,
+                                  BatchAAResults &BAA);
   bool processByValArgument(CallBase &CB, unsigned ArgNo);
   Instruction *tryMergingIntoMemset(Instruction *I, Value *StartPtr,
                                     Value *ByteVal);
index 7ef00ea..8174761 100644 (file)
@@ -335,7 +335,7 @@ void MemCpyOptPass::eraseInstruction(Instruction *I) {
 // Start and End must be in the same block.
 // If SkippedLifetimeStart is provided, skip over one clobbering lifetime.start
 // intrinsic and store it inside SkippedLifetimeStart.
-static bool accessedBetween(AliasAnalysis &AA, MemoryLocation Loc,
+static bool accessedBetween(BatchAAResults &AA, MemoryLocation Loc,
                             const MemoryUseOrDef *Start,
                             const MemoryUseOrDef *End,
                             Instruction **SkippedLifetimeStart = nullptr) {
@@ -359,7 +359,7 @@ static bool accessedBetween(AliasAnalysis &AA, MemoryLocation Loc,
 
 // Check for mod of Loc between Start and End, excluding both boundaries.
 // Start and End can be in different blocks.
-static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA,
+static bool writtenBetween(MemorySSA *MSSA, BatchAAResults &AA,
                            MemoryLocation Loc, const MemoryUseOrDef *Start,
                            const MemoryUseOrDef *End) {
   if (isa<MemoryUse>(End)) {
@@ -380,7 +380,7 @@ static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA,
 
   // TODO: Only walk until we hit Start.
   MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
-      End->getDefiningAccess(), Loc);
+      End->getDefiningAccess(), Loc, AA);
   return !MSSA->dominates(Clobber, Start);
 }
 
@@ -778,11 +778,12 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
       // Detect cases where we're performing call slot forwarding, but
       // happen to be using a load-store pair to implement it, rather than
       // a memcpy.
+      BatchAAResults BAA(*AA);
       auto GetCall = [&]() -> CallInst * {
         // We defer this expensive clobber walk until the cheap checks
         // have been done on the source inside performCallSlotOptzn.
         if (auto *LoadClobber = dyn_cast<MemoryUseOrDef>(
-              MSSA->getWalker()->getClobberingMemoryAccess(LI)))
+                MSSA->getWalker()->getClobberingMemoryAccess(LI, BAA)))
           return dyn_cast_or_null<CallInst>(LoadClobber->getMemoryInst());
         return nullptr;
       };
@@ -791,7 +792,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
           LI, SI, SI->getPointerOperand()->stripPointerCasts(),
           LI->getPointerOperand()->stripPointerCasts(),
           DL.getTypeStoreSize(SI->getOperand(0)->getType()),
-          std::min(SI->getAlign(), LI->getAlign()), GetCall);
+          std::min(SI->getAlign(), LI->getAlign()), BAA, GetCall);
       if (changed) {
         eraseInstruction(SI);
         eraseInstruction(LI);
@@ -872,7 +873,7 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {
 bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
                                          Instruction *cpyStore, Value *cpyDest,
                                          Value *cpySrc, TypeSize cpySize,
-                                         Align cpyAlign,
+                                         Align cpyAlign, BatchAAResults &BAA,
                                          std::function<CallInst *()> GetC) {
   // The general transformation to keep in mind is
   //
@@ -930,7 +931,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
   // Check that nothing touches the dest of the copy between
   // the call and the store/memcpy.
   Instruction *SkippedLifetimeStart = nullptr;
-  if (accessedBetween(*AA, DestLoc, MSSA->getMemoryAccess(C),
+  if (accessedBetween(BAA, DestLoc, MSSA->getMemoryAccess(C),
                       MSSA->getMemoryAccess(cpyStore), &SkippedLifetimeStart)) {
     LLVM_DEBUG(dbgs() << "Call Slot: Dest pointer modified after call\n");
     return false;
@@ -1058,7 +1059,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
       // pointer (we have already any direct mod/refs in the loop above).
       // Also bail if we hit a terminator, as we don't want to scan into other
       // blocks.
-      if (isModOrRefSet(AA->getModRefInfo(&I, SrcLoc)) || I.isTerminator())
+      if (isModOrRefSet(BAA.getModRefInfo(&I, SrcLoc)) || I.isTerminator())
         return false;
     }
   }
@@ -1079,10 +1080,11 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
   // unexpected manner, for example via a global, which we deduce from
   // the use analysis, we also need to know that it does not sneakily
   // access dest.  We rely on AA to figure this out for us.
-  ModRefInfo MR = AA->getModRefInfo(C, cpyDest, LocationSize::precise(srcSize));
+  MemoryLocation DestWithSrcSize(cpyDest, LocationSize::precise(srcSize));
+  ModRefInfo MR = BAA.getModRefInfo(C, DestWithSrcSize);
   // If necessary, perform additional analysis.
   if (isModOrRefSet(MR))
-    MR = AA->callCapturesBefore(C, cpyDest, LocationSize::precise(srcSize), DT);
+    MR = BAA.callCapturesBefore(C, DestWithSrcSize, DT);
   if (isModOrRefSet(MR))
     return false;
 
@@ -1146,7 +1148,8 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
 /// We've found that the (upward scanning) memory dependence of memcpy 'M' is
 /// the memcpy 'MDep'. Try to simplify M to copy from MDep's input if we can.
 bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
-                                                  MemCpyInst *MDep) {
+                                                  MemCpyInst *MDep,
+                                                  BatchAAResults &BAA) {
   // We can only transforms memcpy's where the dest of one is the source of the
   // other.
   if (M->getSource() != MDep->getDest() || MDep->isVolatile())
@@ -1180,7 +1183,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
   // then we could still perform the xform by moving M up to the first memcpy.
   // TODO: It would be sufficient to check the MDep source up to the memcpy
   // size of M, rather than MDep.
-  if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
+  if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep),
                      MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M)))
     return false;
 
@@ -1190,7 +1193,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
   // still want to eliminate the intermediate value, but we have to generate a
   // memmove instead of memcpy.
   bool UseMemMove = false;
-  if (isModSet(AA->getModRefInfo(M, MemoryLocation::getForSource(MDep))))
+  if (isModSet(BAA.getModRefInfo(M, MemoryLocation::getForSource(MDep))))
     UseMemMove = true;
 
   // If all checks passed, then we can transform M.
@@ -1244,20 +1247,21 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
 ///   memset(dst + src_size, c, dst_size <= src_size ? 0 : dst_size - src_size);
 /// \endcode
 bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
-                                                  MemSetInst *MemSet) {
+                                                  MemSetInst *MemSet,
+                                                  BatchAAResults &BAA) {
   // We can only transform memset/memcpy with the same destination.
-  if (!AA->isMustAlias(MemSet->getDest(), MemCpy->getDest()))
+  if (!BAA.isMustAlias(MemSet->getDest(), MemCpy->getDest()))
     return false;
 
   // Check that src and dst of the memcpy aren't the same. While memcpy
   // operands cannot partially overlap, exact equality is allowed.
-  if (isModSet(AA->getModRefInfo(MemCpy, MemoryLocation::getForSource(MemCpy))))
+  if (isModSet(BAA.getModRefInfo(MemCpy, MemoryLocation::getForSource(MemCpy))))
     return false;
 
   // We know that dst up to src_size is not written. We now need to make sure
   // that dst up to dst_size is not accessed. (If we did not move the memset,
   // checking for reads would be sufficient.)
-  if (accessedBetween(*AA, MemoryLocation::getForDest(MemSet),
+  if (accessedBetween(BAA, MemoryLocation::getForDest(MemSet),
                       MSSA->getMemoryAccess(MemSet),
                       MSSA->getMemoryAccess(MemCpy)))
     return false;
@@ -1327,7 +1331,7 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
 
 /// Determine whether the instruction has undefined content for the given Size,
 /// either because it was freshly alloca'd or started its lifetime.
-static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V,
+static bool hasUndefContents(MemorySSA *MSSA, BatchAAResults &AA, Value *V,
                              MemoryDef *Def, Value *Size) {
   if (MSSA->isLiveOnEntryDef(Def))
     return isa<AllocaInst>(getUnderlyingObject(V));
@@ -1337,7 +1341,7 @@ static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V,
       auto *LTSize = cast<ConstantInt>(II->getArgOperand(0));
 
       if (auto *CSize = dyn_cast<ConstantInt>(Size)) {
-        if (AA->isMustAlias(V, II->getArgOperand(1)) &&
+        if (AA.isMustAlias(V, II->getArgOperand(1)) &&
             LTSize->getZExtValue() >= CSize->getZExtValue())
           return true;
       }
@@ -1374,10 +1378,11 @@ static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V,
 /// \endcode
 /// When dst2_size <= dst1_size.
 bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy,
-                                               MemSetInst *MemSet) {
+                                               MemSetInst *MemSet,
+                                               BatchAAResults &BAA) {
   // Make sure that memcpy(..., memset(...), ...), that is we are memsetting and
   // memcpying from the same address. Otherwise it is hard to reason about.
-  if (!AA->isMustAlias(MemSet->getRawDest(), MemCpy->getRawSource()))
+  if (!BAA.isMustAlias(MemSet->getRawDest(), MemCpy->getRawSource()))
     return false;
 
   Value *MemSetSize = MemSet->getLength();
@@ -1405,9 +1410,9 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy,
       bool CanReduceSize = false;
       MemoryUseOrDef *MemSetAccess = MSSA->getMemoryAccess(MemSet);
       MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
-          MemSetAccess->getDefiningAccess(), MemCpyLoc);
+          MemSetAccess->getDefiningAccess(), MemCpyLoc, BAA);
       if (auto *MD = dyn_cast<MemoryDef>(Clobber))
-        if (hasUndefContents(MSSA, AA, MemCpy->getSource(), MD, CopySize))
+        if (hasUndefContents(MSSA, BAA, MemCpy->getSource(), MD, CopySize))
           CanReduceSize = true;
 
       if (!CanReduceSize)
@@ -1464,12 +1469,13 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
         return true;
       }
 
+  BatchAAResults BAA(*AA);
   MemoryUseOrDef *MA = MSSA->getMemoryAccess(M);
   // FIXME: Not using getClobberingMemoryAccess() here due to PR54682.
   MemoryAccess *AnyClobber = MA->getDefiningAccess();
   MemoryLocation DestLoc = MemoryLocation::getForDest(M);
   const MemoryAccess *DestClobber =
-      MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc);
+      MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc, BAA);
 
   // Try to turn a partially redundant memset + memcpy into
   // memcpy + smaller memset.  We don't need the memcpy size for this.
@@ -1478,11 +1484,11 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
   if (auto *MD = dyn_cast<MemoryDef>(DestClobber))
     if (auto *MDep = dyn_cast_or_null<MemSetInst>(MD->getMemoryInst()))
       if (DestClobber->getBlock() == M->getParent())
-        if (processMemSetMemCpyDependence(M, MDep))
+        if (processMemSetMemCpyDependence(M, MDep, BAA))
           return true;
 
   MemoryAccess *SrcClobber = MSSA->getWalker()->getClobberingMemoryAccess(
-      AnyClobber, MemoryLocation::getForSource(M));
+      AnyClobber, MemoryLocation::getForSource(M), BAA);
 
   // There are four possible optimizations we can do for memcpy:
   //   a) memcpy-memcpy xform which exposes redundance for DSE.
@@ -1499,10 +1505,10 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
           // of conservatively taking the minimum?
           Align Alignment = std::min(M->getDestAlign().valueOrOne(),
                                      M->getSourceAlign().valueOrOne());
-          if (performCallSlotOptzn(
-                  M, M, M->getDest(), M->getSource(),
-                  TypeSize::getFixed(CopySize->getZExtValue()), Alignment,
-                  [C]() -> CallInst * { return C; })) {
+          if (performCallSlotOptzn(M, M, M->getDest(), M->getSource(),
+                                   TypeSize::getFixed(CopySize->getZExtValue()),
+                                   Alignment, BAA,
+                                   [C]() -> CallInst * { return C; })) {
             LLVM_DEBUG(dbgs() << "Performed call slot optimization:\n"
                               << "    call: " << *C << "\n"
                               << "    memcpy: " << *M << "\n");
@@ -1513,9 +1519,9 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
         }
       }
       if (auto *MDep = dyn_cast<MemCpyInst>(MI))
-        return processMemCpyMemCpyDependence(M, MDep);
+        return processMemCpyMemCpyDependence(M, MDep, BAA);
       if (auto *MDep = dyn_cast<MemSetInst>(MI)) {
-        if (performMemCpyToMemSetOptzn(M, MDep)) {
+        if (performMemCpyToMemSetOptzn(M, MDep, BAA)) {
           LLVM_DEBUG(dbgs() << "Converted memcpy to memset\n");
           eraseInstruction(M);
           ++NumCpyToSet;
@@ -1524,7 +1530,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
       }
     }
 
-    if (hasUndefContents(MSSA, AA, M->getSource(), MD, M->getLength())) {
+    if (hasUndefContents(MSSA, BAA, M->getSource(), MD, M->getLength())) {
       LLVM_DEBUG(dbgs() << "Removed memcpy from undef\n");
       eraseInstruction(M);
       ++NumMemCpyInstr;
@@ -1571,8 +1577,9 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
   if (!CallAccess)
     return false;
   MemCpyInst *MDep = nullptr;
+  BatchAAResults BAA(*AA);
   MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
-      CallAccess->getDefiningAccess(), Loc);
+      CallAccess->getDefiningAccess(), Loc, BAA);
   if (auto *MD = dyn_cast<MemoryDef>(Clobber))
     MDep = dyn_cast_or_null<MemCpyInst>(MD->getMemoryInst());
 
@@ -1613,7 +1620,7 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
   //    *b = 42;
   //    foo(*a)
   // It would be invalid to transform the second memcpy into foo(*b).
-  if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
+  if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep),
                      MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB)))
     return false;