[DSE] Add value type info checks for masked store candidates in Dead Store Elimination.
authorMichael Berg <michael.berg@sifive.com>
Tue, 20 Sep 2022 22:54:16 +0000 (15:54 -0700)
committerMichael Berg <michael.berg@sifive.com>
Tue, 20 Sep 2022 22:54:25 +0000 (15:54 -0700)
The type information of the store values can diverge when checking for valid
mask store candidates to eliminate via DSE. This patch checks for equivalence
wrt to size and element count.

Reviewed By: fhahn, rui.zhang

Differential Revision: https://reviews.llvm.org/D132700

llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
llvm/test/Transforms/DeadStoreElimination/masked-dead-store.ll

index 45c404a..1ed52cc 100644 (file)
@@ -242,19 +242,30 @@ static OverwriteResult isMaskedStoreOverwrite(const Instruction *KillingI,
   const auto *DeadII = dyn_cast<IntrinsicInst>(DeadI);
   if (KillingII == nullptr || DeadII == nullptr)
     return OW_Unknown;
-  if (KillingII->getIntrinsicID() != Intrinsic::masked_store ||
-      DeadII->getIntrinsicID() != Intrinsic::masked_store)
+  if (KillingII->getIntrinsicID() != DeadII->getIntrinsicID())
     return OW_Unknown;
-  // Pointers.
-  Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts();
-  Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts();
-  if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr))
-    return OW_Unknown;
-  // Masks.
-  // TODO: check that KillingII's mask is a superset of the DeadII's mask.
-  if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
-    return OW_Unknown;
-  return OW_Complete;
+  if (KillingII->getIntrinsicID() == Intrinsic::masked_store) {
+    // Type size.
+    VectorType *KillingTy =
+        cast<VectorType>(KillingII->getArgOperand(0)->getType());
+    VectorType *DeadTy = cast<VectorType>(DeadII->getArgOperand(0)->getType());
+    if (KillingTy->getScalarSizeInBits() != DeadTy->getScalarSizeInBits())
+      return OW_Unknown;
+    // Element count.
+    if (KillingTy->getElementCount() != DeadTy->getElementCount())
+      return OW_Unknown;
+    // Pointers.
+    Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts();
+    Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts();
+    if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr))
+      return OW_Unknown;
+    // Masks.
+    // TODO: check that KillingII's mask is a superset of the DeadII's mask.
+    if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
+      return OW_Unknown;
+    return OW_Complete;
+  }
+  return OW_Unknown;
 }
 
 /// Return 'OW_Complete' if a store to the 'KillingLoc' location completely
index 9c7eea5..5830d45 100644 (file)
@@ -59,7 +59,8 @@ b0:
 
 define dllexport i32 @f1(<4 x i32>* %a, <4 x i8> %v1, <4 x i32> %v2) {
 ; CHECK-LABEL: @f1(
-; CHECK-NEXT:    [[PTR:%.*]] = bitcast <4 x i32>* [[A:%.*]] to <4 x i8>*
+; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[V2:%.*]], <4 x i32>* [[A:%.*]], i32 1, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+; CHECK-NEXT:    [[PTR:%.*]] = bitcast <4 x i32>* [[A]] to <4 x i8>*
 ; CHECK-NEXT:    call void @llvm.masked.store.v4i8.p0v4i8(<4 x i8> [[V1:%.*]], <4 x i8>* [[PTR]], i32 1, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
 ; CHECK-NEXT:    ret i32 0
 ;
@@ -69,6 +70,19 @@ define dllexport i32 @f1(<4 x i32>* %a, <4 x i8> %v1, <4 x i32> %v2) {
   ret i32 0
 }
 
+define dllexport i32 @f2(<4 x i32>* %a, <4 x i8> %v1, <4 x i32> %v2, <4 x i1> %mask) {
+; CHECK-LABEL: @f2(
+; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[V2:%.*]], <4 x i32>* [[A:%.*]], i32 1, <4 x i1> [[MASK:%.*]])
+; CHECK-NEXT:    [[PTR:%.*]] = bitcast <4 x i32>* [[A]] to <4 x i8>*
+; CHECK-NEXT:    call void @llvm.masked.store.v4i8.p0v4i8(<4 x i8> [[V1:%.*]], <4 x i8>* [[PTR]], i32 1, <4 x i1> [[MASK]])
+; CHECK-NEXT:    ret i32 0
+;
+  tail call void @llvm.masked.store.v4i32.p0(<4 x i32> %v2, <4 x i32>* %a, i32 1, <4 x i1> %mask)
+  %ptr = bitcast <4 x i32>* %a to <4 x i8>*
+  tail call void @llvm.masked.store.v4i8.p0(<4 x i8> %v1, <4 x i8>* %ptr, i32 1, <4 x i1> %mask)
+  ret i32 0
+}
+
 declare void @llvm.masked.store.v4i8.p0(<4 x i8>, <4 x i8>*, i32, <4 x i1>)
 declare void @llvm.masked.store.v4i32.p0(<4 x i32>, <4 x i32>*, i32, <4 x i1>)