[Analysis] simplify code for getSplatValue(); NFC
authorSanjay Patel <spatel@rotateright.com>
Fri, 7 Jun 2019 16:09:54 +0000 (16:09 +0000)
committerSanjay Patel <spatel@rotateright.com>
Fri, 7 Jun 2019 16:09:54 +0000 (16:09 +0000)
AFAIK, this is only currently called by TTI, but it could be
used from instcombine or CGP to help solve problems like:
https://bugs.llvm.org/show_bug.cgi?id=37428
https://bugs.llvm.org/show_bug.cgi?id=42174

llvm-svn: 362810

llvm/lib/Analysis/VectorUtils.cpp

index f1a05ae..ea00f0a 100644 (file)
@@ -304,30 +304,21 @@ Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
 
 /// Get splat value if the input is a splat vector or return nullptr.
 /// This function is not fully general. It checks only 2 cases:
-/// the input value is (1) a splat constants vector or (2) a sequence
-/// of instructions that broadcast a single value into a vector.
-///
+/// the input value is (1) a splat constant vector or (2) a sequence
+/// of instructions that broadcasts a scalar at element 0.
 const llvm::Value *llvm::getSplatValue(const Value *V) {
-
-  if (auto *C = dyn_cast<Constant>(V))
-    if (isa<VectorType>(V->getType()))
+  if (isa<VectorType>(V->getType()))
+    if (auto *C = dyn_cast<Constant>(V))
       return C->getSplatValue();
 
-  auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V);
-  if (!ShuffleInst)
-    return nullptr;
-  // All-zero (or undef) shuffle mask elements.
-  for (int MaskElt : ShuffleInst->getShuffleMask())
-    if (MaskElt != 0 && MaskElt != -1)
-      return nullptr;
-  // The first shuffle source is 'insertelement' with index 0.
-  auto *InsertEltInst =
-    dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0));
-  if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) ||
-      !cast<ConstantInt>(InsertEltInst->getOperand(2))->isZero())
-    return nullptr;
+  // shuf (inselt ?, Splat, 0), ?, <0, undef, 0, ...>
+  Value *Splat;
+  if (match(V, m_ShuffleVector(m_InsertElement(m_Value(), m_Value(Splat),
+                                               m_ZeroInt()),
+                               m_Value(), m_ZeroInt())))
+    return Splat;
 
-  return InsertEltInst->getOperand(1);
+  return nullptr;
 }
 
 MapVector<Instruction *, uint64_t>