From aacfe2be53d441d256091b2b495875a69fc2f285 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Sat, 3 Oct 2020 16:26:29 +0100 Subject: [PATCH] [InstCombine] recognizeBSwapOrBitReverseIdiom - add vector support Add basic vector handling to recognizeBSwapOrBitReverseIdiom/collectBitParts - this works at the element level, all vector element operations must match (splat constants etc.) and there is no cross-element support (insert/extract/shuffle etc.). --- llvm/lib/Transforms/Utils/Local.cpp | 24 ++++++++----- llvm/test/Transforms/InstCombine/bswap.ll | 59 +++++++------------------------ 2 files changed, 27 insertions(+), 56 deletions(-) diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index d17ce2f..eea347a 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -2803,7 +2803,7 @@ struct BitPart { /// Analyze the specified subexpression and see if it is capable of providing /// pieces of a bswap or bitreverse. The subexpression provides a potential -/// piece of a bswap or bitreverse if it can be proven that each non-zero bit in +/// piece of a bswap or bitreverse if it can be proved that each non-zero bit in /// the output of the expression came from a corresponding bit in some other /// value. This function is recursive, and the end result is a mapping of /// bitnumber to bitnumber. It is the caller's responsibility to validate that @@ -2815,6 +2815,10 @@ struct BitPart { /// BitPart is returned with Provider set to %X and Provenance[24-31] set to /// [0-7]. /// +/// For vector types, all analysis is performed at the per-element level. No +/// cross-element analysis is supported (shuffle/insertion/reduction), and all +/// constant masks must be splatted across all elements. +/// /// To avoid revisiting values, the BitPart results are memoized into the /// provided map. To avoid unnecessary copying of BitParts, BitParts are /// constructed in-place in the \c BPS map. Because of this \c BPS needs to @@ -3019,14 +3023,14 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( return false; if (!MatchBSwaps && !MatchBitReversals) return false; - IntegerType *ITy = dyn_cast(I->getType()); - if (!ITy || ITy->getBitWidth() > 128) - return false; // Can't do vectors or integers > 128 bits. + Type *ITy = I->getType(); + if (!ITy->isIntOrIntVectorTy() || ITy->getScalarSizeInBits() > 128) + return false; // Can't do integer/elements > 128 bits. - IntegerType *DemandedTy = ITy; + Type *DemandedTy = ITy; if (I->hasOneUse()) if (auto *Trunc = dyn_cast(I->user_back())) - DemandedTy = cast(Trunc->getType()); + DemandedTy = Trunc->getType(); // Try to find all the pieces corresponding to the bswap. std::map> BPS; @@ -3044,12 +3048,14 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( BitProvenance = BitProvenance.drop_back(); if (BitProvenance.empty()) return false; // TODO - handle null value? - DemandedTy = IntegerType::get(I->getContext(), BitProvenance.size()); + DemandedTy = Type::getIntNTy(I->getContext(), BitProvenance.size()); + if (auto *IVecTy = dyn_cast(ITy)) + DemandedTy = VectorType::get(DemandedTy, IVecTy); } // Check BitProvenance hasn't found a source larger than the result type. - unsigned DemandedBW = DemandedTy->getBitWidth(); - if (DemandedBW > ITy->getBitWidth()) + unsigned DemandedBW = DemandedTy->getScalarSizeInBits(); + if (DemandedBW > ITy->getScalarSizeInBits()) return false; // Now, is the bit permutation correct for a bswap or a bitreverse? We can diff --git a/llvm/test/Transforms/InstCombine/bswap.ll b/llvm/test/Transforms/InstCombine/bswap.ll index d6f0792..effbc66 100644 --- a/llvm/test/Transforms/InstCombine/bswap.ll +++ b/llvm/test/Transforms/InstCombine/bswap.ll @@ -22,15 +22,7 @@ define i32 @test1(i32 %i) { define <2 x i32> @test1_vector(<2 x i32> %i) { ; CHECK-LABEL: @test1_vector( -; CHECK-NEXT: [[T1:%.*]] = lshr <2 x i32> [[I:%.*]], -; CHECK-NEXT: [[T3:%.*]] = lshr <2 x i32> [[I]], -; CHECK-NEXT: [[T4:%.*]] = and <2 x i32> [[T3]], -; CHECK-NEXT: [[T5:%.*]] = or <2 x i32> [[T1]], [[T4]] -; CHECK-NEXT: [[T7:%.*]] = shl <2 x i32> [[I]], -; CHECK-NEXT: [[T8:%.*]] = and <2 x i32> [[T7]], -; CHECK-NEXT: [[T9:%.*]] = or <2 x i32> [[T5]], [[T8]] -; CHECK-NEXT: [[T11:%.*]] = shl <2 x i32> [[I]], -; CHECK-NEXT: [[T12:%.*]] = or <2 x i32> [[T9]], [[T11]] +; CHECK-NEXT: [[T12:%.*]] = call <2 x i32> @llvm.bswap.v2i32(<2 x i32> [[I:%.*]]) ; CHECK-NEXT: ret <2 x i32> [[T12]] ; %t1 = lshr <2 x i32> %i, @@ -64,15 +56,7 @@ define i32 @test2(i32 %arg) { define <2 x i32> @test2_vector(<2 x i32> %arg) { ; CHECK-LABEL: @test2_vector( -; CHECK-NEXT: [[T2:%.*]] = shl <2 x i32> [[ARG:%.*]], -; CHECK-NEXT: [[T4:%.*]] = shl <2 x i32> [[ARG]], -; CHECK-NEXT: [[T5:%.*]] = and <2 x i32> [[T4]], -; CHECK-NEXT: [[T6:%.*]] = or <2 x i32> [[T2]], [[T5]] -; CHECK-NEXT: [[T8:%.*]] = lshr <2 x i32> [[ARG]], -; CHECK-NEXT: [[T9:%.*]] = and <2 x i32> [[T8]], -; CHECK-NEXT: [[T10:%.*]] = or <2 x i32> [[T6]], [[T9]] -; CHECK-NEXT: [[T12:%.*]] = lshr <2 x i32> [[ARG]], -; CHECK-NEXT: [[T14:%.*]] = or <2 x i32> [[T10]], [[T12]] +; CHECK-NEXT: [[T14:%.*]] = call <2 x i32> @llvm.bswap.v2i32(<2 x i32> [[ARG:%.*]]) ; CHECK-NEXT: ret <2 x i32> [[T14]] ; %t2 = shl <2 x i32> %arg, @@ -225,15 +209,7 @@ define i32 @test6(i32 %x) nounwind readnone { define <2 x i32> @test6_vector(<2 x i32> %x) nounwind readnone { ; CHECK-LABEL: @test6_vector( -; CHECK-NEXT: [[T:%.*]] = shl <2 x i32> [[X:%.*]], -; CHECK-NEXT: [[X_MASK:%.*]] = and <2 x i32> [[X]], -; CHECK-NEXT: [[T1:%.*]] = lshr <2 x i32> [[X]], -; CHECK-NEXT: [[T2:%.*]] = and <2 x i32> [[T1]], -; CHECK-NEXT: [[T3:%.*]] = or <2 x i32> [[X_MASK]], [[T]] -; CHECK-NEXT: [[T4:%.*]] = or <2 x i32> [[T3]], [[T2]] -; CHECK-NEXT: [[T5:%.*]] = shl <2 x i32> [[T4]], -; CHECK-NEXT: [[T6:%.*]] = lshr <2 x i32> [[X]], -; CHECK-NEXT: [[T7:%.*]] = or <2 x i32> [[T5]], [[T6]] +; CHECK-NEXT: [[T7:%.*]] = call <2 x i32> @llvm.bswap.v2i32(<2 x i32> [[X:%.*]]) ; CHECK-NEXT: ret <2 x i32> [[T7]] ; %t = shl <2 x i32> %x, @@ -381,12 +357,9 @@ define i16 @test10(i32 %a) { define <2 x i16> @test10_vector(<2 x i32> %a) { ; CHECK-LABEL: @test10_vector( -; CHECK-NEXT: [[SHR1:%.*]] = lshr <2 x i32> [[A:%.*]], -; CHECK-NEXT: [[AND1:%.*]] = and <2 x i32> [[SHR1]], -; CHECK-NEXT: [[AND2:%.*]] = shl <2 x i32> [[A]], -; CHECK-NEXT: [[OR:%.*]] = or <2 x i32> [[AND1]], [[AND2]] -; CHECK-NEXT: [[CONV:%.*]] = trunc <2 x i32> [[OR]] to <2 x i16> -; CHECK-NEXT: ret <2 x i16> [[CONV]] +; CHECK-NEXT: [[TRUNC:%.*]] = trunc <2 x i32> [[A:%.*]] to <2 x i16> +; CHECK-NEXT: [[REV:%.*]] = call <2 x i16> @llvm.bswap.v2i16(<2 x i16> [[TRUNC]]) +; CHECK-NEXT: ret <2 x i16> [[REV]] ; %shr1 = lshr <2 x i32> %a, %and1 = and <2 x i32> %shr1, @@ -457,12 +430,10 @@ define i64 @PR39793_bswap_u64_as_u16(i64 %0) { define <2 x i64> @PR39793_bswap_u64_as_u16_vector(<2 x i64> %0) { ; CHECK-LABEL: @PR39793_bswap_u64_as_u16_vector( -; CHECK-NEXT: [[TMP2:%.*]] = lshr <2 x i64> [[TMP0:%.*]], -; CHECK-NEXT: [[TMP3:%.*]] = and <2 x i64> [[TMP2]], -; CHECK-NEXT: [[TMP4:%.*]] = shl <2 x i64> [[TMP0]], -; CHECK-NEXT: [[TMP5:%.*]] = and <2 x i64> [[TMP4]], -; CHECK-NEXT: [[TMP6:%.*]] = or <2 x i64> [[TMP3]], [[TMP5]] -; CHECK-NEXT: ret <2 x i64> [[TMP6]] +; CHECK-NEXT: [[TRUNC:%.*]] = trunc <2 x i64> [[TMP0:%.*]] to <2 x i16> +; CHECK-NEXT: [[REV:%.*]] = call <2 x i16> @llvm.bswap.v2i16(<2 x i16> [[TRUNC]]) +; CHECK-NEXT: [[TMP2:%.*]] = zext <2 x i16> [[REV]] to <2 x i64> +; CHECK-NEXT: ret <2 x i64> [[TMP2]] ; %2 = lshr <2 x i64> %0, %3 = and <2 x i64> %2, @@ -550,14 +521,8 @@ declare i32 @llvm.bswap.i32(i32) define <2 x i32> @partial_bswap_vector(<2 x i32> %x) { ; CHECK-LABEL: @partial_bswap_vector( -; CHECK-NEXT: [[X3:%.*]] = shl <2 x i32> [[X:%.*]], -; CHECK-NEXT: [[A2:%.*]] = shl <2 x i32> [[X]], -; CHECK-NEXT: [[X2:%.*]] = and <2 x i32> [[A2]], -; CHECK-NEXT: [[X32:%.*]] = or <2 x i32> [[X3]], [[X2]] -; CHECK-NEXT: [[T1:%.*]] = and <2 x i32> [[X]], -; CHECK-NEXT: [[T2:%.*]] = call <2 x i32> @llvm.bswap.v2i32(<2 x i32> [[T1]]) -; CHECK-NEXT: [[R:%.*]] = or <2 x i32> [[X32]], [[T2]] -; CHECK-NEXT: ret <2 x i32> [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i32> @llvm.bswap.v2i32(<2 x i32> [[X:%.*]]) +; CHECK-NEXT: ret <2 x i32> [[TMP1]] ; %x3 = shl <2 x i32> %x, %a2 = shl <2 x i32> %x, -- 2.7.4