Simplify llvm.masked.load w/ undef masks
authorDavid Majnemer <david.majnemer@gmail.com>
Thu, 14 Jul 2016 06:58:37 +0000 (06:58 +0000)
committerDavid Majnemer <david.majnemer@gmail.com>
Thu, 14 Jul 2016 06:58:37 +0000 (06:58 +0000)
We can always pick the passthru value if the mask is undef: we are
permitted to treat the mask as-if it were filled with zeros.

llvm-svn: 275379

llvm/lib/Analysis/ConstantFolding.cpp
llvm/lib/Analysis/InstructionSimplify.cpp
llvm/test/Transforms/InstSimplify/call.ll

index 96a2d02..6c471ab 100644 (file)
@@ -1854,32 +1854,39 @@ Constant *ConstantFoldVectorCall(StringRef Name, unsigned IntrinsicID,
     auto *SrcPtr = Operands[0];
     auto *Mask = Operands[2];
     auto *Passthru = Operands[3];
+
     Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, VTy, DL);
-    if (!VecData)
-      return nullptr;
 
     SmallVector<Constant *, 32> NewElements;
     for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) {
-      auto *MaskElt =
-          dyn_cast_or_null<ConstantInt>(Mask->getAggregateElement(I));
+      auto *MaskElt = Mask->getAggregateElement(I);
       if (!MaskElt)
         break;
-      if (MaskElt->isZero()) {
-        auto *PassthruElt = Passthru->getAggregateElement(I);
+      auto *PassthruElt = Passthru->getAggregateElement(I);
+      auto *VecElt = VecData ? VecData->getAggregateElement(I) : nullptr;
+      if (isa<UndefValue>(MaskElt)) {
+        if (PassthruElt)
+          NewElements.push_back(PassthruElt);
+        else if (VecElt)
+          NewElements.push_back(VecElt);
+        else
+          return nullptr;
+      }
+      if (MaskElt->isNullValue()) {
         if (!PassthruElt)
-          break;
+          return nullptr;
         NewElements.push_back(PassthruElt);
-      } else {
-        assert(MaskElt->isOne());
-        auto *VecElt = VecData->getAggregateElement(I);
+      } else if (MaskElt->isOneValue()) {
         if (!VecElt)
-          break;
+          return nullptr;
         NewElements.push_back(VecElt);
+      } else {
+        return nullptr;
       }
     }
-    if (NewElements.size() == VTy->getNumElements())
-      return ConstantVector::get(NewElements);
-    return nullptr;
+    if (NewElements.size() != VTy->getNumElements())
+      return nullptr;
+    return ConstantVector::get(NewElements);
   }
 
   for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) {
index 609cd26..0cb2c78 100644 (file)
@@ -3944,6 +3944,22 @@ static Value *SimplifyRelativeLoad(Constant *Ptr, Constant *Offset,
   return ConstantExpr::getBitCast(LoadedLHSPtr, Int8PtrTy);
 }
 
+static bool maskIsAllZeroOrUndef(Value *Mask) {
+  auto *ConstMask = dyn_cast<Constant>(Mask);
+  if (!ConstMask)
+    return false;
+  if (ConstMask->isNullValue() || isa<UndefValue>(ConstMask))
+    return true;
+  for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E;
+       ++I) {
+    if (auto *MaskElt = ConstMask->getAggregateElement(I))
+      if (MaskElt->isNullValue() || isa<UndefValue>(MaskElt))
+        continue;
+    return false;
+  }
+  return true;
+}
+
 template <typename IterTy>
 static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,
                                 const Query &Q, unsigned MaxRecurse) {
@@ -3993,11 +4009,11 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,
 
   // Simplify calls to llvm.masked.load.*
   if (IID == Intrinsic::masked_load) {
-    IterTy MaskArg = ArgBegin + 2;
-    // If the mask is all zeros, the "passthru" argument is the result.
-    if (auto *ConstMask = dyn_cast<Constant>(*MaskArg))
-      if (ConstMask->isNullValue())
-        return ArgBegin[3];
+    Value *MaskArg = ArgBegin[2];
+    Value *PassthruArg = ArgBegin[3];
+    // If the mask is all zeros or undef, the "passthru" argument is the result.
+    if (maskIsAllZeroOrUndef(MaskArg))
+      return PassthruArg;
   }
 
   // Perform idempotent optimizations
index e0a071a..988ec2b 100644 (file)
@@ -213,6 +213,13 @@ define <8 x i32> @partial_masked_load() {
   ret <8 x i32> %masked.load
 }
 
+define <8 x i32> @masked_load_undef_mask(<8 x i32>* %V) {
+; CHECK-LABEL: @masked_load_undef_mask(
+; CHECK:         ret <8 x i32> <i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0>
+  %masked.load = call <8 x i32> @llvm.masked.load.v8i32.p0v8i32(<8 x i32>* %V, i32 4, <8 x i1> undef, <8 x i32> <i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0>)
+  ret <8 x i32> %masked.load
+}
+
 declare noalias i8* @malloc(i64)
 
 declare <8 x i32> @llvm.masked.load.v8i32.p0v8i32(<8 x i32>*, i32, <8 x i1>, <8 x i32>)