From f03b069c5b70b59a9cb391a4c41250083aa6b2b4 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Wed, 2 Nov 2022 09:15:11 -0400 Subject: [PATCH] [InstCombine] fold mul with decremented "shl -1" factor (2nd try) This is a corrected version of: bc886e9b587b I made a copy-paste error that created an "add" instead of the intended "sub" on that attempt. The regression tests showed the bug, but I overlooked that. As I said in a comment on issue #58717, the bug reports resulting from the botched patch confirm that the pattern does occur in many real-world applications, so hopefully eliminating the multiply results in better code. I added one more regression test in this version of the patch, and here's an Alive2 proof to show that exact example: https://alive2.llvm.org/ce/z/dge7VC Original commit message: This is a sibling to: 6064e92b0a84 ...but we canonicalize the shl+add to shl+xor, so the pattern is different than I expected: https://alive2.llvm.org/ce/z/8CX16e I have not found any patterns that are safe to propagate no-wrap, so that is not included here. Differential Revision: https://reviews.llvm.org/D137157 --- .../InstCombine/InstCombineMulDivRem.cpp | 13 ++++++- llvm/test/Transforms/InstCombine/mul.ll | 44 ++++++++++++---------- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index e4fccda..abc88e3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -140,7 +140,7 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, return nullptr; } -/// Reduce integer multiplication patterns that contain a (1 << Z) factor. +/// Reduce integer multiplication patterns that contain a (+/-1 << Z) factor. /// Callers are expected to call this twice to handle commuted patterns. 
static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, InstCombiner::BuilderTy &Builder) { @@ -171,6 +171,17 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, return Builder.CreateAdd(Shl, FrX, Mul.getName(), HasNUW, PropagateNSW); } + // Similar to above, but a decrement of the shifted value is disguised as + // 'not' and becomes a sub: + // X * (~(-1 << Z)) --> X * ((1 << Z) - 1) --> (X << Z) - X + // This increases uses of X, so it may require a freeze, but that is still + // expected to be an improvement because it removes the multiply. + if (match(Y, m_OneUse(m_Not(m_OneUse(m_Shl(m_AllOnes(), m_Value(Z))))))) { + Value *FrX = Builder.CreateFreeze(X, X->getName() + ".fr"); + Value *Shl = Builder.CreateShl(FrX, Z, "mulshl"); + return Builder.CreateSub(Shl, FrX, Mul.getName()); + } + return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll index 5f0e6ce..4cb2468 100644 --- a/llvm/test/Transforms/InstCombine/mul.ll +++ b/llvm/test/Transforms/InstCombine/mul.ll @@ -227,12 +227,14 @@ define i32 @shl1_increment_use(i32 %x, i32 %y) { ret i32 %m } +; ((-1 << x) ^ -1) * y --> (y << x) - y + define i8 @shl1_decrement(i8 %x, i8 %y) { ; CHECK-LABEL: @shl1_decrement( -; CHECK-NEXT: [[POW2X:%.*]] = shl i8 -1, [[X:%.*]] -; CHECK-NEXT: [[X1:%.*]] = xor i8 [[POW2X]], -1 -; CHECK-NEXT: [[M:%.*]] = mul i8 [[X1]], [[Y:%.*]] -; CHECK-NEXT: ret i8 [[M]] +; CHECK-NEXT: [[Y_FR:%.*]] = freeze i8 [[Y:%.*]] +; CHECK-NEXT: [[MULSHL:%.*]] = shl i8 [[Y_FR]], [[X:%.*]] +; CHECK-NEXT: [[M1:%.*]] = sub i8 [[MULSHL]], [[Y_FR]] +; CHECK-NEXT: ret i8 [[M1]] ; %pow2x = shl i8 -1, %x %x1 = xor i8 %pow2x, -1 @@ -243,10 +245,9 @@ define i8 @shl1_decrement(i8 %x, i8 %y) { define i8 @shl1_decrement_commute(i8 %x, i8 noundef %p) { ; CHECK-LABEL: @shl1_decrement_commute( ; CHECK-NEXT: [[Y:%.*]] = ashr i8 [[P:%.*]], 1 -; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i8 -1, [[X:%.*]] -; CHECK-NEXT: [[X1:%.*]] = xor i8 
[[NOTMASK]], -1 -; CHECK-NEXT: [[M:%.*]] = mul i8 [[Y]], [[X1]] -; CHECK-NEXT: ret i8 [[M]] +; CHECK-NEXT: [[MULSHL:%.*]] = shl i8 [[Y]], [[X:%.*]] +; CHECK-NEXT: [[M1:%.*]] = sub i8 [[MULSHL]], [[Y]] +; CHECK-NEXT: ret i8 [[M1]] ; %y = ashr i8 %p, 1 ; thwart complexity-based canonicalization %pow2x = shl i8 1, %x @@ -257,10 +258,10 @@ define i8 @shl1_decrement_commute(i8 %x, i8 noundef %p) { define i8 @shl1_nuw_decrement(i8 %x, i8 %y) { ; CHECK-LABEL: @shl1_nuw_decrement( -; CHECK-NEXT: [[POW2X:%.*]] = shl i8 -1, [[X:%.*]] -; CHECK-NEXT: [[X1:%.*]] = xor i8 [[POW2X]], -1 -; CHECK-NEXT: [[M:%.*]] = mul nuw i8 [[X1]], [[Y:%.*]] -; CHECK-NEXT: ret i8 [[M]] +; CHECK-NEXT: [[Y_FR:%.*]] = freeze i8 [[Y:%.*]] +; CHECK-NEXT: [[MULSHL:%.*]] = shl i8 [[Y_FR]], [[X:%.*]] +; CHECK-NEXT: [[M1:%.*]] = sub i8 [[MULSHL]], [[Y_FR]] +; CHECK-NEXT: ret i8 [[M1]] ; %pow2x = shl i8 -1, %x %x1 = xor i8 %pow2x, -1 @@ -270,10 +271,10 @@ define i8 @shl1_nuw_decrement(i8 %x, i8 %y) { define i8 @shl1_nsw_decrement(i8 %x, i8 %y) { ; CHECK-LABEL: @shl1_nsw_decrement( -; CHECK-NEXT: [[POW2X:%.*]] = shl nsw i8 -1, [[X:%.*]] -; CHECK-NEXT: [[X1:%.*]] = xor i8 [[POW2X]], -1 -; CHECK-NEXT: [[M:%.*]] = mul nsw i8 [[X1]], [[Y:%.*]] -; CHECK-NEXT: ret i8 [[M]] +; CHECK-NEXT: [[Y_FR:%.*]] = freeze i8 [[Y:%.*]] +; CHECK-NEXT: [[MULSHL:%.*]] = shl i8 [[Y_FR]], [[X:%.*]] +; CHECK-NEXT: [[M1:%.*]] = sub i8 [[MULSHL]], [[Y_FR]] +; CHECK-NEXT: ret i8 [[M1]] ; %pow2x = shl nsw i8 -1, %x %x1 = xor i8 %pow2x, -1 @@ -281,6 +282,8 @@ define i8 @shl1_nsw_decrement(i8 %x, i8 %y) { ret i8 %m } +; negative test - extra use would require more instructions + define i32 @shl1_decrement_use(i32 %x, i32 %y) { ; CHECK-LABEL: @shl1_decrement_use( ; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i32 -1, [[X:%.*]] @@ -296,12 +299,13 @@ define i32 @shl1_decrement_use(i32 %x, i32 %y) { ret i32 %m } +; the fold works for vectors too and if 'y' is a constant, sub becomes add + define <2 x i8> @shl1_decrement_vec(<2 x i8> %x) { ; 
CHECK-LABEL: @shl1_decrement_vec( -; CHECK-NEXT: [[POW2X:%.*]] = shl <2 x i8> <i8 -1, i8 -1>, [[X:%.*]] -; CHECK-NEXT: [[X1:%.*]] = xor <2 x i8> [[POW2X]], <i8 -1, i8 -1> -; CHECK-NEXT: [[M:%.*]] = mul <2 x i8> [[X1]], <i8 42, i8 42> -; CHECK-NEXT: ret <2 x i8> [[M]] +; CHECK-NEXT: [[MULSHL:%.*]] = shl <2 x i8> <i8 42, i8 42>, [[X:%.*]] +; CHECK-NEXT: [[M1:%.*]] = add <2 x i8> [[MULSHL]], <i8 -42, i8 -42> +; CHECK-NEXT: ret <2 x i8> [[M1]] ; %pow2x = shl <2 x i8> <i8 -1, i8 -1>, %x %x1 = xor <2 x i8> %pow2x, <i8 -1, i8 -1> -- 2.7.4