[InstCombine] canonicalize rotate patterns with cmp/select
authorSanjay Patel <spatel@rotateright.com>
Tue, 13 Nov 2018 22:47:24 +0000 (22:47 +0000)
committerSanjay Patel <spatel@rotateright.com>
Tue, 13 Nov 2018 22:47:24 +0000 (22:47 +0000)
The cmp+branch variant of this pattern is shown in:
https://bugs.llvm.org/show_bug.cgi?id=34924
...and as discussed there, we probably can't transform
that without a rotate intrinsic. We do have that now
via funnel shift, but we're not quite ready to
canonicalize IR to that form yet. The case with 'select'
should already be transformed though, so that's this patch.

The sequence with negation followed by masking is what we
use in the backend and partly in clang (though that part
should be updated).

https://rise4fun.com/Alive/TplC
  %cmp = icmp eq i32 %shamt, 0
  %sub = sub i32 32, %shamt
  %shr = lshr i32 %x, %shamt
  %shl = shl i32 %x, %sub
  %or = or i32 %shr, %shl
  %r = select i1 %cmp, i32 %x, i32 %or
  =>
  %neg = sub i32 0, %shamt
  %masked = and i32 %shamt, 31
  %maskedneg = and i32 %neg, 31
  %shl2 = lshr i32 %x, %masked
  %shr2 = shl i32 %x, %maskedneg
  %r = or i32 %shl2, %shr2

llvm-svn: 346807

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
llvm/test/Transforms/InstCombine/rotate.ll

index 88a72bb..26d0b52 100644 (file)
@@ -1546,6 +1546,66 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS,
   return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp);
 }
 
+/// Try to reduce a rotate pattern that includes a compare and select into a
+/// sequence of ALU ops only. Example:
+/// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b)))
+///              --> (a >> (-b & 31)) | (a << (b & 31))
+static Instruction *foldSelectRotate(SelectInst &Sel,
+                                     InstCombiner::BuilderTy &Builder) {
+  // The false value of the select must be a rotate of the true value.
+  Value *Or0, *Or1;
+  if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+    return nullptr;
+
+  Value *TVal = Sel.getTrueValue();
+  Value *SA0, *SA1;
+  if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA0)))) ||
+      !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA1)))))
+    return nullptr;
+
+  auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
+  auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
+  if (ShiftOpcode0 == ShiftOpcode1)
+    return nullptr;
+
+  // We have one of these patterns so far:
+  // select ?, TVal, (or (lshr TVal, SA0), (shl TVal, SA1))
+  // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1))
+  // This must be a power-of-2 rotate for a bitmasking transform to be valid.
+  unsigned Width = Sel.getType()->getScalarSizeInBits();
+  if (!isPowerOf2_32(Width))
+    return nullptr;
+
+  // Check the shift amounts to see if they are an opposite pair.
+  Value *ShAmt;
+  if (match(SA1, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA0)))))
+    ShAmt = SA0;
+  else if (match(SA0, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA1)))))
+    ShAmt = SA1;
+  else
+    return nullptr;
+
+  // Finally, see if the select is filtering out a shift-by-zero.
+  Value *Cond = Sel.getCondition();
+  ICmpInst::Predicate Pred;
+  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()))) ||
+      Pred != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way.
+  // Convert to safely bitmasked shifts.
+  // TODO: When we can canonicalize to funnel shift intrinsics without risk of
+  // performance regressions, replace this sequence with that call.
+  Value *NegShAmt = Builder.CreateNeg(ShAmt);
+  Value *MaskedShAmt = Builder.CreateAnd(ShAmt, Width - 1);
+  Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, Width - 1);
+  Value *NewSA0 = ShAmt == SA0 ? MaskedShAmt : MaskedNegShAmt;
+  Value *NewSA1 = ShAmt == SA1 ? MaskedShAmt : MaskedNegShAmt;
+  Value *NewSh0 = Builder.CreateBinOp(ShiftOpcode0, TVal, NewSA0);
+  Value *NewSh1 = Builder.CreateBinOp(ShiftOpcode1, TVal, NewSA1);
+  return BinaryOperator::CreateOr(NewSh0, NewSh1);
+}
+
 Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -2010,5 +2070,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI))
     return Select;
 
+  if (Instruction *Rot = foldSelectRotate(SI, Builder))
+    return Rot;
+
   return nullptr;
 }
index 4401539..6150063 100644 (file)
@@ -309,16 +309,16 @@ define i8 @rotateleft_8_neg_mask_wide_amount_commute(i8 %v, i32 %shamt) {
   ret i8 %ret
 }
 
-; TODO: Convert select pattern to masked shift that ends in 'or'.
+; Convert select pattern to masked shift that ends in 'or'.
 
 define i32 @rotr_select(i32 %x, i32 %shamt) {
 ; CHECK-LABEL: @rotr_select(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[SHAMT:%.*]], 0
-; CHECK-NEXT:    [[SUB:%.*]] = sub i32 32, [[SHAMT]]
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[X:%.*]], [[SHAMT]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[X]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHR]], [[SHL]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i32 [[X]], i32 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub i32 0, [[SHAMT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[SHAMT]], 31
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[X:%.*]], [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = shl i32 [[X]], [[TMP3]]
+; CHECK-NEXT:    [[R:%.*]] = or i32 [[TMP4]], [[TMP5]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %cmp = icmp eq i32 %shamt, 0
@@ -330,16 +330,16 @@ define i32 @rotr_select(i32 %x, i32 %shamt) {
   ret i32 %r
 }
 
-; TODO: Convert select pattern to masked shift that ends in 'or'.
+; Convert select pattern to masked shift that ends in 'or'.
 
 define i8 @rotr_select_commute(i8 %x, i8 %shamt) {
 ; CHECK-LABEL: @rotr_select_commute(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[SHAMT:%.*]], 0
-; CHECK-NEXT:    [[SUB:%.*]] = sub i8 8, [[SHAMT]]
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i8 [[X:%.*]], [[SHAMT]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl i8 [[X]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[SHL]], [[SHR]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i8 [[X]], i8 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub i8 0, [[SHAMT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[SHAMT]], 7
+; CHECK-NEXT:    [[TMP3:%.*]] = and i8 [[TMP1]], 7
+; CHECK-NEXT:    [[TMP4:%.*]] = shl i8 [[X:%.*]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i8 [[X]], [[TMP2]]
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[TMP4]], [[TMP5]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %cmp = icmp eq i8 %shamt, 0
@@ -351,16 +351,16 @@ define i8 @rotr_select_commute(i8 %x, i8 %shamt) {
   ret i8 %r
 }
 
-; TODO: Convert select pattern to masked shift that ends in 'or'.
+; Convert select pattern to masked shift that ends in 'or'.
 
 define i16 @rotl_select(i16 %x, i16 %shamt) {
 ; CHECK-LABEL: @rotl_select(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[SHAMT:%.*]], 0
-; CHECK-NEXT:    [[SUB:%.*]] = sub i16 16, [[SHAMT]]
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i16 [[X:%.*]], [[SUB]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl i16 [[X]], [[SHAMT]]
-; CHECK-NEXT:    [[OR:%.*]] = or i16 [[SHR]], [[SHL]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i16 [[X]], i16 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub i16 0, [[SHAMT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i16 [[SHAMT]], 15
+; CHECK-NEXT:    [[TMP3:%.*]] = and i16 [[TMP1]], 15
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i16 [[X:%.*]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = shl i16 [[X]], [[TMP2]]
+; CHECK-NEXT:    [[R:%.*]] = or i16 [[TMP4]], [[TMP5]]
 ; CHECK-NEXT:    ret i16 [[R]]
 ;
   %cmp = icmp eq i16 %shamt, 0
@@ -372,24 +372,45 @@ define i16 @rotl_select(i16 %x, i16 %shamt) {
   ret i16 %r
 }
 
-; TODO: Convert select pattern to masked shift that ends in 'or'.
+; Convert select pattern to masked shift that ends in 'or'.
 
-define i64 @rotl_select_commute(i64 %x, i64 %shamt) {
+define <2 x i64> @rotl_select_commute(<2 x i64> %x, <2 x i64> %shamt) {
 ; CHECK-LABEL: @rotl_select_commute(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[SHAMT:%.*]], 0
-; CHECK-NEXT:    [[SUB:%.*]] = sub i64 64, [[SHAMT]]
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i64 [[X:%.*]], [[SUB]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl i64 [[X]], [[SHAMT]]
-; CHECK-NEXT:    [[OR:%.*]] = or i64 [[SHL]], [[SHR]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i64 [[X]], i64 [[OR]]
-; CHECK-NEXT:    ret i64 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <2 x i64> zeroinitializer, [[SHAMT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i64> [[SHAMT]], <i64 63, i64 63>
+; CHECK-NEXT:    [[TMP3:%.*]] = and <2 x i64> [[TMP1]], <i64 63, i64 63>
+; CHECK-NEXT:    [[TMP4:%.*]] = shl <2 x i64> [[X:%.*]], [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr <2 x i64> [[X]], [[TMP3]]
+; CHECK-NEXT:    [[R:%.*]] = or <2 x i64> [[TMP4]], [[TMP5]]
+; CHECK-NEXT:    ret <2 x i64> [[R]]
+;
+  %cmp = icmp eq <2 x i64> %shamt, zeroinitializer
+  %sub = sub <2 x i64> <i64 64, i64 64>, %shamt
+  %shr = lshr <2 x i64> %x, %sub
+  %shl = shl <2 x i64> %x, %shamt
+  %or = or <2 x i64> %shl, %shr
+  %r = select <2 x i1> %cmp, <2 x i64> %x, <2 x i64> %or
+  ret <2 x i64> %r
+}
+
+; Negative test - the transform is only valid with power-of-2 types.
+
+define i24 @rotl_select_weird_type(i24 %x, i24 %shamt) {
+; CHECK-LABEL: @rotl_select_weird_type(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i24 [[SHAMT:%.*]], 0
+; CHECK-NEXT:    [[SUB:%.*]] = sub i24 24, [[SHAMT]]
+; CHECK-NEXT:    [[SHR:%.*]] = lshr i24 [[X:%.*]], [[SUB]]
+; CHECK-NEXT:    [[SHL:%.*]] = shl i24 [[X]], [[SHAMT]]
+; CHECK-NEXT:    [[OR:%.*]] = or i24 [[SHL]], [[SHR]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i24 [[X]], i24 [[OR]]
+; CHECK-NEXT:    ret i24 [[R]]
 ;
-  %cmp = icmp eq i64 %shamt, 0
-  %sub = sub i64 64, %shamt
-  %shr = lshr i64 %x, %sub
-  %shl = shl i64 %x, %shamt
-  %or = or i64 %shl, %shr
-  %r = select i1 %cmp, i64 %x, i64 %or
-  ret i64 %r
+  %cmp = icmp eq i24 %shamt, 0
+  %sub = sub i24 24, %shamt
+  %shr = lshr i24 %x, %sub
+  %shl = shl i24 %x, %shamt
+  %or = or i24 %shl, %shr
+  %r = select i1 %cmp, i24 %x, i24 %or
+  ret i24 %r
 }