[InstCombine] fold multiply by signbit-splat to cmp+select
authorSanjay Patel <spatel@rotateright.com>
Fri, 27 May 2022 15:27:14 +0000 (11:27 -0400)
committerSanjay Patel <spatel@rotateright.com>
Fri, 27 May 2022 15:54:19 +0000 (11:54 -0400)
(ashr i32 X, 31) * C --> (X < 0) ? -C : 0
https://alive2.llvm.org/ce/z/G8u9SS

With a constant operand, this is an improvement in IR
and codegen (where it can be converted to a mask op).

Without a constant operand, we would have to negate
the operand, so that is probably better left to the backend.

This is similar but not the same optimization that is requested
in #55618.

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
llvm/test/Transforms/InstCombine/mul.ll

index a2feacc..897580a 100644 (file)
@@ -356,12 +356,22 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
     return SelectInst::Create(X, Op0, ConstantInt::getNullValue(Ty));
 
-  // (sext bool X) * C --> X ? -C : 0
   Constant *ImmC;
-  if (match(Op0, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1) &&
-      match(Op1, m_ImmConstant(ImmC))) {
-    Constant *NegC = ConstantExpr::getNeg(ImmC);
-    return SelectInst::Create(X, NegC, ConstantInt::getNullValue(Ty));
+  if (match(Op1, m_ImmConstant(ImmC))) {
+    // (sext bool X) * C --> X ? -C : 0
+    if (match(Op0, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
+      Constant *NegC = ConstantExpr::getNeg(ImmC);
+      return SelectInst::Create(X, NegC, ConstantInt::getNullValue(Ty));
+    }
+
+    // (ashr i32 X, 31) * C --> (X < 0) ? -C : 0
+    const APInt *C;
+    if (match(Op0, m_OneUse(m_AShr(m_Value(X), m_APInt(C)))) &&
+        *C == C->getBitWidth() - 1) {
+      Constant *NegC = ConstantExpr::getNeg(ImmC);
+      Value *IsNeg = Builder.CreateIsNeg(X, "isneg");
+      return SelectInst::Create(IsNeg, NegC, ConstantInt::getNullValue(Ty));
+    }
   }
 
   // (lshr X, 31) * Y --> (X < 0) ? Y : 0
index 47fd5fc..895ce43 100644 (file)
@@ -464,8 +464,8 @@ define <2 x i32> @signbit_mul_vec_commute(<2 x i32> %a, <2 x i32> %b) {
 
 define i32 @signsplat_mul(i32 %x) {
 ; CHECK-LABEL: @signsplat_mul(
-; CHECK-NEXT:    [[ASH:%.*]] = ashr i32 [[X:%.*]], 31
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[ASH]], 42
+; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt i32 [[X:%.*]], 0
+; CHECK-NEXT:    [[MUL:%.*]] = select i1 [[ISNEG]], i32 -42, i32 0
 ; CHECK-NEXT:    ret i32 [[MUL]]
 ;
   %ash = ashr i32 %x, 31
@@ -475,8 +475,8 @@ define i32 @signsplat_mul(i32 %x) {
 
 define <2 x i32> @signsplat_mul_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @signsplat_mul_vec(
-; CHECK-NEXT:    [[ASH:%.*]] = ashr <2 x i32> [[X:%.*]], <i32 31, i32 31>
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw <2 x i32> [[ASH]], <i32 42, i32 -3>
+; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt <2 x i32> [[X:%.*]], zeroinitializer
+; CHECK-NEXT:    [[MUL:%.*]] = select <2 x i1> [[ISNEG]], <2 x i32> <i32 -42, i32 3>, <2 x i32> zeroinitializer
 ; CHECK-NEXT:    ret <2 x i32> [[MUL]]
 ;
   %ash = ashr <2 x i32> %x, <i32 31, i32 31>
@@ -484,6 +484,8 @@ define <2 x i32> @signsplat_mul_vec(<2 x i32> %x) {
   ret <2 x i32> %mul
 }
 
+; negative test - wrong shift amount
+
 define i32 @not_signsplat_mul(i32 %x) {
 ; CHECK-LABEL: @not_signsplat_mul(
 ; CHECK-NEXT:    [[ASH:%.*]] = ashr i32 [[X:%.*]], 30
@@ -495,6 +497,8 @@ define i32 @not_signsplat_mul(i32 %x) {
   ret i32 %mul
 }
 
+; negative test - extra use
+
 define i32 @signsplat_mul_use(i32 %x) {
 ; CHECK-LABEL: @signsplat_mul_use(
 ; CHECK-NEXT:    [[ASH:%.*]] = ashr i32 [[X:%.*]], 31