[DeadStoreElimination] Shorten beginning of memset overwritten by later stores
authorJun Bum Lim <junbuml@codeaurora.org>
Fri, 22 Apr 2016 19:51:29 +0000 (19:51 +0000)
committerJun Bum Lim <junbuml@codeaurora.org>
Fri, 22 Apr 2016 19:51:29 +0000 (19:51 +0000)
Summary: This change will shorten memset if the beginning of memset is overwritten by later stores.

Reviewers: hfinkel, eeckstein, dberlin, mcrosier

Subscribers: mgrang, mcrosier, llvm-commits

Differential Revision: http://reviews.llvm.org/D18906

llvm-svn: 267197

llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
llvm/test/Transforms/DeadStoreElimination/OverwriteStoreBegin.ll [new file with mode: 0644]

index 61bbdf0..39026c6 100644 (file)
@@ -275,9 +275,9 @@ static bool isRemovable(Instruction *I) {
 }
 
 
-/// isShortenable - Returns true if this instruction can be safely shortened in
+/// Returns true if the end of this instruction can be safely shortened in
 /// length.
-static bool isShortenable(Instruction *I) {
+static bool isShortenableAtTheEnd(Instruction *I) {
   // Don't shorten stores for now
   if (isa<StoreInst>(I))
     return false;
@@ -288,6 +288,7 @@ static bool isShortenable(Instruction *I) {
       case Intrinsic::memset:
       case Intrinsic::memcpy:
         // Do shorten memory intrinsics.
+        // FIXME: Add memmove if it's also safe to transform.
         return true;
     }
   }
@@ -297,6 +298,15 @@ static bool isShortenable(Instruction *I) {
   return false;
 }
 
+/// Returns true if the beginning of this instruction can be safely shortened
+/// in length.
+static bool isShortenableAtTheBeginning(Instruction *I) {
+  // FIXME: Handle only memset for now. Supporting memcpy/memmove should be
+  // easily done by offsetting the source address.
+  IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
+  return II && II->getIntrinsicID() == Intrinsic::memset;
+}
+
 /// getStoredPointerOperand - Return the pointer that is being written to.
 static Value *getStoredPointerOperand(Instruction *I) {
   if (StoreInst *SI = dyn_cast<StoreInst>(I))
@@ -327,18 +337,19 @@ static uint64_t getPointerSize(const Value *V, const DataLayout &DL,
 }
 
 namespace {
-  enum OverwriteResult
-  {
-    OverwriteComplete,
-    OverwriteEnd,
-    OverwriteUnknown
-  };
+enum OverwriteResult {
+  OverwriteBegin,
+  OverwriteComplete,
+  OverwriteEnd,
+  OverwriteUnknown
+};
 }
 
-/// isOverwrite - Return 'OverwriteComplete' if a store to the 'Later' location
-/// completely overwrites a store to the 'Earlier' location.
-/// 'OverwriteEnd' if the end of the 'Earlier' location is completely
-/// overwritten by 'Later', or 'OverwriteUnknown' if nothing can be determined
+/// Return 'OverwriteComplete' if a store to the 'Later' location completely
+/// overwrites a store to the 'Earlier' location, 'OverwriteEnd' if the end of
+/// the 'Earlier' location is completely overwritten by 'Later',
+/// 'OverwriteBegin' if the beginning of the 'Earlier' location is overwritten
+/// by 'Later', or 'OverwriteUnknown' if nothing can be determined.
 static OverwriteResult isOverwrite(const MemoryLocation &Later,
                                    const MemoryLocation &Earlier,
                                    const DataLayout &DL,
@@ -416,8 +427,8 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
       uint64_t(EarlierOff - LaterOff) + Earlier.Size <= Later.Size)
     return OverwriteComplete;
 
-  // The other interesting case is if the later store overwrites the end of
-  // the earlier store
+  // Another interesting case is if the later store overwrites the end of the
+  // earlier store.
   //
   //      |--earlier--|
   //                |--   later   --|
@@ -429,6 +440,20 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later,
       int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size))
     return OverwriteEnd;
 
+  // Finally, we also need to check if the later store overwrites the beginning
+  // of the earlier store.
+  //
+  //                |--earlier--|
+  //      |--   later   --|
+  //
+  // In this case we may want to move the destination address and trim the size
+  // of earlier to avoid generating writes to addresses which will definitely
+  // be overwritten later.
+  if (LaterOff <= EarlierOff && int64_t(LaterOff + Later.Size) > EarlierOff) {
+    assert (int64_t(LaterOff + Later.Size) < int64_t(EarlierOff + Earlier.Size)
+            && "Expect to be handled as OverwriteComplete" );
+    return OverwriteBegin;
+  }
   // Otherwise, they don't completely overlap.
   return OverwriteUnknown;
 }
@@ -603,29 +628,49 @@ bool DSE::runOnBasicBlock(BasicBlock &BB) {
           if (BBI != BB.begin())
             --BBI;
           break;
-        } else if (OR == OverwriteEnd && isShortenable(DepWrite)) {
+        } else if ((OR == OverwriteEnd && isShortenableAtTheEnd(DepWrite)) ||
+                   ((OR == OverwriteBegin &&
+                     isShortenableAtTheBeginning(DepWrite)))) {
           // TODO: base this on the target vector size so that if the earlier
           // store was too small to get vector writes anyway then its likely
           // a good idea to shorten it
           // Power of 2 vector writes are probably always a bad idea to optimize
           // as any store/memset/memcpy is likely using vector instructions so
           // shortening it to not vector size is likely to be slower
-          MemIntrinsicDepIntrinsic = cast<MemIntrinsic>(DepWrite);
+          MemIntrinsic *DepIntrinsic = cast<MemIntrinsic>(DepWrite);
           unsigned DepWriteAlign = DepIntrinsic->getAlignment();
-          if (llvm::isPowerOf2_64(InstWriteOffset) ||
+          bool IsOverwriteEnd = (OR == OverwriteEnd);
+          if (!IsOverwriteEnd)
+            InstWriteOffset = int64_t(InstWriteOffset + Loc.Size);
+
+          if ((llvm::isPowerOf2_64(InstWriteOffset) &&
+               DepWriteAlign <= InstWriteOffset) ||
               ((DepWriteAlign != 0) && InstWriteOffset % DepWriteAlign == 0)) {
 
-            DEBUG(dbgs() << "DSE: Remove Dead Store:\n  OW END: "
-                  << *DepWrite << "\n  KILLER (offset "
-                  << InstWriteOffset << ", "
-                  << DepLoc.Size << ")"
-                  << *Inst << '\n');
+            DEBUG(dbgs() << "DSE: Remove Dead Store:\n  OW "
+                         << (IsOverwriteEnd ? "END" : "BEGIN") << ": "
+                         << *DepWrite << "\n  KILLER (offset "
+                         << InstWriteOffset << ", " << DepLoc.Size << ")"
+                         << *Inst << '\n');
 
-            Value* DepWriteLength = DepIntrinsic->getLength();
-            Value* TrimmedLength = ConstantInt::get(DepWriteLength->getType(),
-                                                    InstWriteOffset -
-                                                    DepWriteOffset);
+            int64_t NewLength =
+                IsOverwriteEnd
+                    ? InstWriteOffset - DepWriteOffset
+                    : DepLoc.Size - (InstWriteOffset - DepWriteOffset);
+
+            Value *DepWriteLength = DepIntrinsic->getLength();
+            Value *TrimmedLength =
+                ConstantInt::get(DepWriteLength->getType(), NewLength);
             DepIntrinsic->setLength(TrimmedLength);
+
+            if (!IsOverwriteEnd) {
+              int64_t OffsetMoved = (InstWriteOffset - DepWriteOffset);
+              Value *Indices[1] = {
+                  ConstantInt::get(DepWriteLength->getType(), OffsetMoved)};
+              GetElementPtrInst *NewDestGEP = GetElementPtrInst::CreateInBounds(
+                  DepIntrinsic->getRawDest(), Indices, "", DepWrite);
+              DepIntrinsic->setDest(NewDestGEP);
+            }
             MadeChange = true;
           }
         }
diff --git a/llvm/test/Transforms/DeadStoreElimination/OverwriteStoreBegin.ll b/llvm/test/Transforms/DeadStoreElimination/OverwriteStoreBegin.ll
new file mode 100644 (file)
index 0000000..0bcd851
--- /dev/null
@@ -0,0 +1,90 @@
+; RUN: opt < %s -basicaa -dse -S | FileCheck %s
+
+define void @write4to7(i32* nocapture %p) {
+; CHECK-LABEL: @write4to7(
+entry:
+  %arrayidx0 = getelementptr inbounds i32, i32* %p, i64 1
+  %p3 = bitcast i32* %arrayidx0 to i8*
+; CHECK: [[GEP:%[0-9]+]] = getelementptr inbounds i8, i8* %p3, i64 4
+; CHECK: call void @llvm.memset.p0i8.i64(i8* [[GEP]], i8 0, i64 24, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 28, i32 4, i1 false)
+  %arrayidx1 = getelementptr inbounds i32, i32* %p, i64 1
+  store i32 1, i32* %arrayidx1, align 4
+  ret void
+}
+
+define void @write0to3(i32* nocapture %p) {
+; CHECK-LABEL: @write0to3(
+entry:
+  %p3 = bitcast i32* %p to i8*
+; CHECK: [[GEP:%[0-9]+]] = getelementptr inbounds i8, i8* %p3, i64 4
+; CHECK: call void @llvm.memset.p0i8.i64(i8* [[GEP]], i8 0, i64 24, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 28, i32 4, i1 false)
+  store i32 1, i32* %p, align 4
+  ret void
+}
+
+define void @write0to7(i32* nocapture %p) {
+; CHECK-LABEL: @write0to7(
+entry:
+  %p3 = bitcast i32* %p to i8*
+; CHECK: [[GEP:%[0-9]+]] = getelementptr inbounds i8, i8* %p3, i64 8
+; CHECK: call void @llvm.memset.p0i8.i64(i8* [[GEP]], i8 0, i64 24, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 4, i1 false)
+  %p4 = bitcast i32* %p to i64*
+  store i64 1, i64* %p4, align 8
+  ret void
+}
+
+define void @write0to7_2(i32* nocapture %p) {
+; CHECK-LABEL: @write0to7_2(
+entry:
+  %arrayidx0 = getelementptr inbounds i32, i32* %p, i64 1
+  %p3 = bitcast i32* %arrayidx0 to i8*
+; CHECK: [[GEP:%[0-9]+]] = getelementptr inbounds i8, i8* %p3, i64 4
+; CHECK: call void @llvm.memset.p0i8.i64(i8* [[GEP]], i8 0, i64 24, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 28, i32 4, i1 false)
+  %p4 = bitcast i32* %p to i64*
+  store i64 1, i64* %p4, align 8
+  ret void
+}
+
+; We do not trim the beginning of the eariler write if the alignment of the
+; start pointer is changed.
+define void @dontwrite0to3_align8(i32* nocapture %p) {
+; CHECK-LABEL: @dontwrite0to3_align8(
+entry:
+  %p3 = bitcast i32* %p to i8*
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 8, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 8, i1 false)
+  store i32 1, i32* %p, align 4
+  ret void
+}
+
+define void @dontwrite0to1(i32* nocapture %p) {
+; CHECK-LABEL: @dontwrite0to1(
+entry:
+  %p3 = bitcast i32* %p to i8*
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 4, i1 false)
+  %p4 = bitcast i32* %p to i16*
+  store i16 1, i16* %p4, align 4
+  ret void
+}
+
+define void @dontwrite2to9(i32* nocapture %p) {
+; CHECK-LABEL: @dontwrite2to9(
+entry:
+  %arrayidx0 = getelementptr inbounds i32, i32* %p, i64 1
+  %p3 = bitcast i32* %arrayidx0 to i8*
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 4, i1 false)
+  %p4 = bitcast i32* %p to i16*
+  %arrayidx2 = getelementptr inbounds i16, i16* %p4, i64 1
+  %p5 = bitcast i16* %arrayidx2 to i64*
+  store i64 1, i64* %p5, align 8
+  ret void
+}
+
+declare void @llvm.memset.p0i8.i64(i8* nocapture, i8, i64, i32, i1) nounwind
+