[DSE,MSSA] Eliminate stores by terminators (free,lifetime.end).
authorFlorian Hahn <flo@fhahn.com>
Wed, 8 Jul 2020 07:42:55 +0000 (08:42 +0100)
committerFlorian Hahn <flo@fhahn.com>
Wed, 8 Jul 2020 07:59:46 +0000 (08:59 +0100)
This patch adds support for eliminating stores by free & lifetime.end
calls. We can remove stores that are not read before calling a memory
terminator and we can eliminate all stores after a memory terminator
until we see a new lifetime.start. The second case seems to not really
trigger much in practice though.

Reviewers: dmgreen, rnk, efriedma, bryant, asbirlea, Tyker

Reviewed By: asbirlea

Differential Revision: https://reviews.llvm.org/D72410

llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
llvm/test/Transforms/DeadStoreElimination/MSSA/2016-07-17-UseAfterFree.ll
llvm/test/Transforms/DeadStoreElimination/MSSA/free.ll
llvm/test/Transforms/DeadStoreElimination/MSSA/lifetime.ll
llvm/test/Transforms/DeadStoreElimination/MSSA/memset-missing-debugloc.ll
llvm/test/Transforms/DeadStoreElimination/MSSA/multiblock-captures.ll
llvm/test/Transforms/DeadStoreElimination/MSSA/multiblock-malloc-free.ll
llvm/test/Transforms/DeadStoreElimination/MSSA/simple.ll

index dd8dc84..e58db03 100644 (file)
@@ -51,6 +51,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -73,6 +74,7 @@
 #include <utility>
 
 using namespace llvm;
+using namespace PatternMatch;
 
 #define DEBUG_TYPE "dse"
 
@@ -1533,7 +1535,7 @@ struct DSEState {
 
         auto *MD = dyn_cast_or_null<MemoryDef>(MA);
         if (MD && State.MemDefs.size() < MemorySSADefsPerBlockLimit &&
-            State.getLocForWriteEx(&I))
+            (State.getLocForWriteEx(&I) || State.isMemTerminatorInst(&I)))
           State.MemDefs.push_back(MD);
 
         // Track whether alloca and alloca-like objects are visible in the
@@ -1667,6 +1669,51 @@ struct DSEState {
     return true;
   }
 
+  /// If \p I is a memory  terminator like llvm.lifetime.end or free, return a
+  /// pair with the MemoryLocation terminated by \p I and a boolean flag
+  /// indicating whether \p I is a free-like call.
+  Optional<std::pair<MemoryLocation, bool>>
+  getLocForTerminator(Instruction *I) const {
+    uint64_t Len;
+    Value *Ptr;
+    if (match(I, m_Intrinsic<Intrinsic::lifetime_end>(m_ConstantInt(Len),
+                                                      m_Value(Ptr))))
+      return {std::make_pair(MemoryLocation(Ptr, Len), false)};
+
+    if (auto *CB = dyn_cast<CallBase>(I)) {
+      if (isFreeCall(I, &TLI))
+        return {std::make_pair(MemoryLocation(CB->getArgOperand(0)), true)};
+    }
+
+    return None;
+  }
+
+  /// Returns true if \p I is a memory terminator instruction like
+  /// llvm.lifetime.end or free.
+  bool isMemTerminatorInst(Instruction *I) const {
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
+    return (II && II->getIntrinsicID() == Intrinsic::lifetime_end) ||
+           isFreeCall(I, &TLI);
+  }
+
+  /// Returns true if \p MaybeTerm is a memory terminator for the same
+  /// underlying object as \p DefLoc.
+  bool isMemTerminator(MemoryLocation DefLoc, Instruction *MaybeTerm) const {
+    Optional<std::pair<MemoryLocation, bool>> MaybeTermLoc =
+        getLocForTerminator(MaybeTerm);
+
+    if (!MaybeTermLoc)
+      return false;
+
+    // If the terminator is a free-like call, all accesses to the underlying
+    // object can be considered terminated.
+    if (MaybeTermLoc->second) {
+      DataLayout DL = MaybeTerm->getParent()->getModule()->getDataLayout();
+      DefLoc = MemoryLocation(GetUnderlyingObject(DefLoc.Ptr, DL));
+    }
+    return AA.isMustAlias(MaybeTermLoc->first, DefLoc);
+  }
+
   // Returns true if \p Use may read from \p DefLoc.
   bool isReadClobber(MemoryLocation DefLoc, Instruction *UseInst) const {
     if (!UseInst->mayReadFromMemory())
@@ -1772,6 +1819,11 @@ struct DSEState {
         continue;
       }
 
+      // A memory terminator kills all preceeding MemoryDefs and all succeeding
+      // MemoryAccesses. We do not have to check it's users.
+      if (isMemTerminator(DefLoc, UseInst))
+        continue;
+
       // Uses which may read the original MemoryDef mean we cannot eliminate the
       // original MD. Stop walk.
       if (isReadClobber(DefLoc, UseInst)) {
@@ -2059,6 +2111,12 @@ bool eliminateDeadStoresMemorySSA(Function &F, AliasAnalysis &AA,
     Instruction *SI = KillingDef->getMemoryInst();
 
     auto MaybeSILoc = State.getLocForWriteEx(SI);
+    if (State.isMemTerminatorInst(SI))
+      MaybeSILoc = State.getLocForTerminator(SI).map(
+          [](const std::pair<MemoryLocation, bool> &P) { return P.first; });
+    else
+      MaybeSILoc = State.getLocForWriteEx(SI);
+
     if (!MaybeSILoc) {
       LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
                         << *SI << "\n");
@@ -2165,43 +2223,55 @@ bool eliminateDeadStoresMemorySSA(Function &F, AliasAnalysis &AA,
         continue;
 
       MemoryLocation NILoc = *State.getLocForWriteEx(NI);
-      // Check if NI overwrites SI.
-      int64_t InstWriteOffset, DepWriteOffset;
-      auto Iter = State.IOLs.insert(
-          std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
-              NI->getParent(), InstOverlapIntervalsTy()));
-      auto &IOL = Iter.first->second;
-      OverwriteResult OR = isOverwrite(SILoc, NILoc, DL, TLI, DepWriteOffset,
-                                       InstWriteOffset, NI, IOL, AA, &F);
-
-      if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
-        auto *Earlier = dyn_cast<StoreInst>(NI);
-        auto *Later = dyn_cast<StoreInst>(SI);
-        if (Constant *Merged = tryToMergePartialOverlappingStores(
-                Earlier, Later, InstWriteOffset, DepWriteOffset, DL, &AA,
-                &DT)) {
-
-          // Update stored value of earlier store to merged constant.
-          Earlier->setOperand(0, Merged);
-          ++NumModifiedStores;
-          MadeChange = true;
-
-          // Remove later store and remove any outstanding overlap intervals for
-          // the updated store.
-          State.deleteDeadInstruction(Later);
-          auto I = State.IOLs.find(Earlier->getParent());
-          if (I != State.IOLs.end())
-            I->second.erase(Earlier);
-          break;
-        }
-      }
 
-      if (OR == OW_Complete) {
+      if (State.isMemTerminatorInst(SI)) {
+        const Value *NIUnd = GetUnderlyingObject(NILoc.Ptr, DL);
+        if (!SILocUnd || SILocUnd != NIUnd)
+          continue;
         LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: " << *NI
                           << "\n  KILLER: " << *SI << '\n');
         State.deleteDeadInstruction(NI);
         ++NumFastStores;
         MadeChange = true;
+      } else {
+        // Check if NI overwrites SI.
+        int64_t InstWriteOffset, DepWriteOffset;
+        auto Iter = State.IOLs.insert(
+            std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
+                NI->getParent(), InstOverlapIntervalsTy()));
+        auto &IOL = Iter.first->second;
+        OverwriteResult OR = isOverwrite(SILoc, NILoc, DL, TLI, DepWriteOffset,
+                                         InstWriteOffset, NI, IOL, AA, &F);
+
+        if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
+          auto *Earlier = dyn_cast<StoreInst>(NI);
+          auto *Later = dyn_cast<StoreInst>(SI);
+          if (Constant *Merged = tryToMergePartialOverlappingStores(
+                  Earlier, Later, InstWriteOffset, DepWriteOffset, DL, &AA,
+                  &DT)) {
+
+            // Update stored value of earlier store to merged constant.
+            Earlier->setOperand(0, Merged);
+            ++NumModifiedStores;
+            MadeChange = true;
+
+            // Remove later store and remove any outstanding overlap intervals
+            // for the updated store.
+            State.deleteDeadInstruction(Later);
+            auto I = State.IOLs.find(Earlier->getParent());
+            if (I != State.IOLs.end())
+              I->second.erase(Earlier);
+            break;
+          }
+        }
+
+        if (OR == OW_Complete) {
+          LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: " << *NI
+                            << "\n  KILLER: " << *SI << '\n');
+          State.deleteDeadInstruction(NI);
+          ++NumFastStores;
+          MadeChange = true;
+        }
       }
     }
   }
index f5d7e25..85a749f 100644 (file)
@@ -1,5 +1,4 @@
-; XFAIL: *
-; RUN: opt < %s -basic-aa -dse-enable-dse-memoryssa  -S -enable-dse-partial-overwrite-tracking | FileCheck %s
+; RUN: opt < %s -basic-aa -dse -enable-dse-memoryssa  -S -enable-dse-partial-overwrite-tracking | FileCheck %s
 ; PR28588
 
 target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
index 81e64c8..13cfb70 100644 (file)
@@ -1,5 +1,3 @@
-; XFAIL: *
-
 ; RUN: opt < %s -basic-aa -dse -enable-dse-memoryssa -S | FileCheck %s
 
 target datalayout = "e-p:64:64:64"
index 222c293..29ff772 100644 (file)
@@ -1,5 +1,3 @@
-; XFAIL: *
-
 ; RUN: opt -S -basic-aa -dse -enable-dse-memoryssa < %s | FileCheck %s
 
 target datalayout = "E-p:64:64:64-a0:0:8-f32:32:32-f64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-v64:64:64-v128:128:128"
@@ -35,5 +33,3 @@ define void @test2(i32* %P) {
 ; CHECK: lifetime.end
   ret void
 }
-
-
index 80db7f5..c28f0cc 100644 (file)
@@ -1,3 +1,4 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; Test that the getelementptr generated when the dse pass determines that
 ; a memset can be shortened has the debugloc carried over from the memset.
 
index c7949e7..e287139 100644 (file)
@@ -207,6 +207,7 @@ exit:
   call void @capture(i8* %m)
   ret i8* %m
 }
+
 ; Stores to stack objects can be eliminated if they are not captured inside the function.
 define void @test_alloca_nocapture_1() {
 ; CHECK-LABEL: @test_alloca_nocapture_1(
index c8b951e..04cdae2 100644 (file)
@@ -17,13 +17,12 @@ declare void @free(i8* nocapture) #2
 define void @test16(i32* noalias %P) {
 ; CHECK-LABEL: @test16(
 ; CHECK-NEXT:    [[P2:%.*]] = bitcast i32* [[P:%.*]] to i8*
-; CHECK-NEXT:    store i32 1, i32* [[P]]
 ; CHECK-NEXT:    br i1 true, label [[BB1:%.*]], label [[BB3:%.*]]
 ; CHECK:       bb1:
-; CHECK-NEXT:    store i32 1, i32* [[P]]
 ; CHECK-NEXT:    br label [[BB3]]
 ; CHECK:       bb3:
 ; CHECK-NEXT:    call void @free(i8* [[P2]])
+; CHECK-NEXT:    store i32 1, i32* [[P]]
 ; CHECK-NEXT:    ret void
 ;
   %P2 = bitcast i32* %P to i8*
@@ -34,6 +33,7 @@ bb1:
   br label %bb3
 bb3:
   call void @free(i8* %P2)
+  store i32 1, i32* %P
   ret void
 }
 
@@ -41,11 +41,9 @@ bb3:
 define void @test17(i32* noalias %P) {
 ; CHECK-LABEL: @test17(
 ; CHECK-NEXT:    [[P2:%.*]] = bitcast i32* [[P:%.*]] to i8*
-; CHECK-NEXT:    store i32 1, i32* [[P]]
 ; CHECK-NEXT:    br i1 true, label [[BB1:%.*]], label [[BB3:%.*]]
 ; CHECK:       bb1:
 ; CHECK-NEXT:    call void @unknown_func()
-; CHECK-NEXT:    store i32 1, i32* [[P]]
 ; CHECK-NEXT:    br label [[BB3]]
 ; CHECK:       bb3:
 ; CHECK-NEXT:    call void @free(i8* [[P2]])
@@ -63,6 +61,30 @@ bb3:
   ret void
 }
 
+define void @test17_read_after_free(i32* noalias %P) {
+; CHECK-LABEL: @test17_read_after_free(
+; CHECK-NEXT:    [[P2:%.*]] = bitcast i32* [[P:%.*]] to i8*
+; CHECK-NEXT:    br i1 true, label [[BB1:%.*]], label [[BB3:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    br label [[BB3]]
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @free(i8* [[P2]])
+; CHECK-NEXT:    [[LV:%.*]] = load i8, i8* [[P2]]
+; CHECK-NEXT:    ret void
+;
+  %P2 = bitcast i32* %P to i8*
+  store i32 1, i32* %P
+  br i1 true, label %bb1, label %bb3
+bb1:
+  store i32 1, i32* %P
+  br label %bb3
+bb3:
+  call void @free(i8* %P2)
+  %lv = load i8, i8* %P2
+  ret void
+}
+
+
 define void @test6(i32* noalias %P) {
 ; CHECK-LABEL: @test6(
 ; CHECK-NEXT:    br i1 true, label [[BB1:%.*]], label [[BB2:%.*]]
index 8411c1d..ef785f1 100644 (file)
@@ -625,7 +625,6 @@ define void @test41(i32* noalias %P) {
 ; CHECK-NEXT:    [[P2:%.*]] = bitcast i32* [[P:%.*]] to i8*
 ; CHECK-NEXT:    store i32 1, i32* [[P]], align 4
 ; CHECK-NEXT:    call void @unknown_func()
-; CHECK-NEXT:    store i32 2, i32* [[P]], align 4
 ; CHECK-NEXT:    call void @free(i8* [[P2]])
 ; CHECK-NEXT:    ret void
 ;