return NewShift;
}
+// If we have some pattern that leaves only some low bits set, and then performs
+// left-shift of those bits, if none of the bits that are left after the final
+// shift are modified by the mask, we can omit the mask.
+//
+// There are many variants to this pattern:
+// a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
+// All these patterns can be simplified to just:
+// x << ShiftShAmt
+// iff:
+// a) (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
+static Instruction *
+dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
+ const SimplifyQuery &SQ) {
+ assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl &&
+ "The input must be 'shl'!");
+
+ Value *Masked = OuterShift->getOperand(0);
+ Value *ShiftShAmt = OuterShift->getOperand(1);
+
+ Value *MaskShAmt;
+
+ // ((1 << MaskShAmt) - 1)
+ auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
+
+ Value *X;
+ if (!match(Masked, m_c_And(MaskA, m_Value(X))))
+ return nullptr;
+
+ // Can we simplify (MaskShAmt+ShiftShAmt) ?
+ Value *SumOfShAmts =
+ SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
+ SQ.getWithInstruction(OuterShift));
+ if (!SumOfShAmts)
+ return nullptr; // Did not simplify.
+ // Is the total shift amount *not* smaller than the bit width?
+ // FIXME: could also rely on ConstantRange.
+ unsigned BitWidth = X->getType()->getScalarSizeInBits();
+ if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
+ APInt(BitWidth, BitWidth))))
+ return nullptr;
+ // All good, we can do this fold.
+
+ // No 'NUW'/'NSW'!
+ // We no longer know that we won't shift-out non-0 bits.
+ return BinaryOperator::Create(OuterShift->getOpcode(), X, ShiftShAmt);
+}
+
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
assert(Op0->getType() == Op1->getType());
if (Instruction *V = commonShiftTransforms(I))
return V;
+ if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, SQ))
+ return V;
+
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = I.getType();
unsigned BitWidth = Ty->getScalarSizeInBits();
; CHECK-NEXT: call void @use32(i32 [[T1]])
; CHECK-NEXT: call void @use32(i32 [[T2]])
; CHECK-NEXT: call void @use32(i32 [[T3]])
-; CHECK-NEXT: [[T4:%.*]] = shl i32 [[T2]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
; CHECK-NEXT: ret i32 [[T4]]
;
%t0 = shl i32 1, %nbits
; CHECK-NEXT: call void @use32(i32 [[T1]])
; CHECK-NEXT: call void @use32(i32 [[T2]])
; CHECK-NEXT: call void @use32(i32 [[T3]])
-; CHECK-NEXT: [[T4:%.*]] = shl i32 [[T2]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
; CHECK-NEXT: ret i32 [[T4]]
;
%t0 = shl i32 1, %nbits
; CHECK-NEXT: call void @use32(i32 [[T2]])
; CHECK-NEXT: call void @use32(i32 [[T3]])
; CHECK-NEXT: call void @use32(i32 [[T4]])
-; CHECK-NEXT: [[T5:%.*]] = shl i32 [[T3]], [[T4]]
+; CHECK-NEXT: [[T5:%.*]] = shl i32 [[X]], [[T4]]
; CHECK-NEXT: ret i32 [[T5]]
;
%t0 = add i32 %nbits, 1
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T3]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T4]])
-; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[T3]], [[T4]]
+; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[X]], [[T4]]
; CHECK-NEXT: ret <3 x i32> [[T5]]
;
%t0 = add <3 x i32> %nbits, <i32 0, i32 0, i32 0>
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T3]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T4]])
-; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[T3]], [[T4]]
+; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[X]], [[T4]]
; CHECK-NEXT: ret <3 x i32> [[T5]]
;
%t0 = add <3 x i32> %nbits, <i32 -1, i32 0, i32 1>
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T3]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T4]])
-; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[T3]], [[T4]]
+; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[X]], [[T4]]
; CHECK-NEXT: ret <3 x i32> [[T5]]
;
%t0 = add <3 x i32> %nbits, <i32 0, i32 undef, i32 0>
; CHECK-NEXT: call void @use32(i32 [[T1]])
; CHECK-NEXT: call void @use32(i32 [[T2]])
; CHECK-NEXT: call void @use32(i32 [[T3]])
-; CHECK-NEXT: [[T4:%.*]] = shl i32 [[T2]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
; CHECK-NEXT: ret i32 [[T4]]
;
%x = call i32 @gen32()
; CHECK-NEXT: call void @use32(i32 [[T3]])
; CHECK-NEXT: call void @use32(i32 [[T4]])
; CHECK-NEXT: call void @use32(i32 [[T5]])
-; CHECK-NEXT: [[T6:%.*]] = shl i32 [[T4]], [[T5]]
+; CHECK-NEXT: [[T6:%.*]] = shl i32 [[T1]], [[T5]]
; CHECK-NEXT: ret i32 [[T6]]
;
%t0 = shl i32 1, %nbits0
; CHECK-NEXT: call void @use32(i32 [[T1]])
; CHECK-NEXT: call void @use32(i32 [[T2]])
; CHECK-NEXT: call void @use32(i32 [[T3]])
-; CHECK-NEXT: [[T4:%.*]] = shl nuw i32 [[T2]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
; CHECK-NEXT: ret i32 [[T4]]
;
%t0 = shl i32 1, %nbits
; CHECK-NEXT: call void @use32(i32 [[T1]])
; CHECK-NEXT: call void @use32(i32 [[T2]])
; CHECK-NEXT: call void @use32(i32 [[T3]])
-; CHECK-NEXT: [[T4:%.*]] = shl nsw i32 [[T2]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
; CHECK-NEXT: ret i32 [[T4]]
;
%t0 = shl i32 1, %nbits
; CHECK-NEXT: call void @use32(i32 [[T1]])
; CHECK-NEXT: call void @use32(i32 [[T2]])
; CHECK-NEXT: call void @use32(i32 [[T3]])
-; CHECK-NEXT: [[T4:%.*]] = shl nuw nsw i32 [[T2]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
; CHECK-NEXT: ret i32 [[T4]]
;
%t0 = shl i32 1, %nbits