From 9b9e2da07dd3b103e5a41a3519d839117d994ffa Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Sun, 2 Feb 2020 09:16:42 -0500 Subject: [PATCH] [Analysis] add optional index parameter to isSplatValue() We want to allow splat value transforms to improve PR44588 and related bugs: https://bugs.llvm.org/show_bug.cgi?id=44588 ...but to do that, we need to know if values are splatted from the same, specific index (lane) rather than splatted from an arbitrary index. We can improve the undef handling with 1-liner follow-ups because the Constant API optionally allow undefs now. Differential Revision: https://reviews.llvm.org/D73549 --- llvm/include/llvm/Analysis/VectorUtils.h | 8 +- llvm/lib/Analysis/VectorUtils.cpp | 29 ++++--- llvm/unittests/Analysis/VectorUtilsTest.cpp | 118 +++++++++++++++++++++++++++- 3 files changed, 142 insertions(+), 13 deletions(-) diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h index f0b0f15..7726cf0 100644 --- a/llvm/include/llvm/Analysis/VectorUtils.h +++ b/llvm/include/llvm/Analysis/VectorUtils.h @@ -306,11 +306,13 @@ Value *findScalarElement(Value *V, unsigned EltNo); /// a sequence of instructions that broadcast a single value into a vector. const Value *getSplatValue(const Value *V); -/// Return true if the input value is known to be a vector with all identical -/// elements (potentially including undefined elements). +/// Return true if each element of the vector value \p V is poisoned or equal to +/// every other non-poisoned element. If an index element is specified, either +/// every element of the vector is poisoned or the element at that index is not +/// poisoned and equal to every other non-poisoned element. /// This may be more powerful than the related getSplatValue() because it is /// not limited by finding a scalar source value to a splatted vector. -bool isSplatValue(const Value *V, unsigned Depth = 0); +bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0); /// Compute a map of integer instructions to their minimum legal type /// size. diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index 0fb09e3..e4b0010 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -330,21 +330,32 @@ const llvm::Value *llvm::getSplatValue(const Value *V) { // adjusted if needed. const unsigned MaxDepth = 6; -bool llvm::isSplatValue(const Value *V, unsigned Depth) { +bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) { assert(Depth <= MaxDepth && "Limit Search Depth"); if (isa(V->getType())) { if (isa(V)) return true; - // FIXME: Constant splat analysis does not allow undef elements. + // FIXME: We can allow undefs, but if Index was specified, we may want to + // check that the constant is defined at that index. if (auto *C = dyn_cast(V)) return C->getSplatValue() != nullptr; } - // FIXME: Constant splat analysis does not allow undef elements. - Constant *Mask; - if (match(V, m_ShuffleVector(m_Value(), m_Value(), m_Constant(Mask)))) - return Mask->getSplatValue() != nullptr; + if (auto *Shuf = dyn_cast(V)) { + // FIXME: We can safely allow undefs here. If Index was specified, we will + // check that the mask elt is defined at the required index. + if (!Shuf->getMask()->getSplatValue()) + return false; + + // Match any index. + if (Index == -1) + return true; + + // Match a specific element. The mask should be defined at and match the + // specified index. + return Shuf->getMaskValue(Index) == Index; + } // The remaining tests are all recursive, so bail out if we hit the limit. if (Depth++ == MaxDepth) @@ -353,12 +364,12 @@ bool llvm::isSplatValue(const Value *V, unsigned Depth) { // If both operands of a binop are splats, the result is a splat. Value *X, *Y, *Z; if (match(V, m_BinOp(m_Value(X), m_Value(Y)))) - return isSplatValue(X, Depth) && isSplatValue(Y, Depth); + return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth); // If all operands of a select are splats, the result is a splat. if (match(V, m_Select(m_Value(X), m_Value(Y), m_Value(Z)))) - return isSplatValue(X, Depth) && isSplatValue(Y, Depth) && - isSplatValue(Z, Depth); + return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth) && + isSplatValue(Z, Index, Depth); // TODO: Add support for unary ops (fneg), casts, intrinsics (overflow ops). diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp index 0743160..ea5282f 100644 --- a/llvm/unittests/Analysis/VectorUtilsTest.cpp +++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp @@ -107,6 +107,24 @@ TEST_F(VectorUtilsTest, isSplatValue_00) { EXPECT_TRUE(isSplatValue(A)); } +TEST_F(VectorUtilsTest, isSplatValue_00_index0) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> zeroinitializer\n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_TRUE(isSplatValue(A, 0)); +} + +TEST_F(VectorUtilsTest, isSplatValue_00_index1) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> zeroinitializer\n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_FALSE(isSplatValue(A, 1)); +} + TEST_F(VectorUtilsTest, isSplatValue_11) { parseAssembly( "define <2 x i8> @test(<2 x i8> %x) {\n" @@ -116,6 +134,24 @@ TEST_F(VectorUtilsTest, isSplatValue_11) { EXPECT_TRUE(isSplatValue(A)); } +TEST_F(VectorUtilsTest, isSplatValue_11_index0) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_FALSE(isSplatValue(A, 0)); +} + +TEST_F(VectorUtilsTest, isSplatValue_11_index1) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_TRUE(isSplatValue(A, 1)); +} + TEST_F(VectorUtilsTest, isSplatValue_01) { parseAssembly( "define <2 x i8> @test(<2 x i8> %x) {\n" @@ -125,7 +161,25 @@ TEST_F(VectorUtilsTest, isSplatValue_01) { EXPECT_FALSE(isSplatValue(A)); } -// FIXME: Constant (mask) splat analysis does not allow undef elements. +TEST_F(VectorUtilsTest, isSplatValue_01_index0) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_FALSE(isSplatValue(A, 0)); +} + +TEST_F(VectorUtilsTest, isSplatValue_01_index1) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_FALSE(isSplatValue(A, 1)); +} + +// FIXME: Allow undef matching with Constant (mask) splat analysis. TEST_F(VectorUtilsTest, isSplatValue_0u) { parseAssembly( @@ -136,6 +190,26 @@ TEST_F(VectorUtilsTest, isSplatValue_0u) { EXPECT_FALSE(isSplatValue(A)); } +// FIXME: Allow undef matching with Constant (mask) splat analysis. + +TEST_F(VectorUtilsTest, isSplatValue_0u_index0) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_FALSE(isSplatValue(A, 0)); +} + +TEST_F(VectorUtilsTest, isSplatValue_0u_index1) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_FALSE(isSplatValue(A, 1)); +} + TEST_F(VectorUtilsTest, isSplatValue_Binop) { parseAssembly( "define <2 x i8> @test(<2 x i8> %x) {\n" @@ -147,6 +221,28 @@ TEST_F(VectorUtilsTest, isSplatValue_Binop) { EXPECT_TRUE(isSplatValue(A)); } +TEST_F(VectorUtilsTest, isSplatValue_Binop_index0) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %v0 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " %A = udiv <2 x i8> %v0, %v1\n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_FALSE(isSplatValue(A, 0)); +} + +TEST_F(VectorUtilsTest, isSplatValue_Binop_index1) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %v0 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " %A = udiv <2 x i8> %v0, %v1\n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_FALSE(isSplatValue(A, 1)); +} + TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0) { parseAssembly( "define <2 x i8> @test(<2 x i8> %x) {\n" @@ -157,6 +253,26 @@ TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0) { EXPECT_TRUE(isSplatValue(A)); } +TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0_index0) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " %A = ashr <2 x i8> , %v1\n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_FALSE(isSplatValue(A, 0)); +} + +TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0_index1) { + parseAssembly( + "define <2 x i8> @test(<2 x i8> %x) {\n" + " %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> \n" + " %A = ashr <2 x i8> , %v1\n" + " ret <2 x i8> %A\n" + "}\n"); + EXPECT_TRUE(isSplatValue(A, 1)); +} + TEST_F(VectorUtilsTest, isSplatValue_Binop_Not_Op0) { parseAssembly( "define <2 x i8> @test(<2 x i8> %x) {\n" -- 2.7.4