[InstCombine] `sext(trunc(x)) --> sext(x)` iff trunc is NSW (PR49543)
authorRoman Lebedev <lebedev.ri@gmail.com>
Tue, 20 Apr 2021 19:16:11 +0000 (22:16 +0300)
committerRoman Lebedev <lebedev.ri@gmail.com>
Tue, 20 Apr 2021 21:31:45 +0000 (00:31 +0300)
If we can tell that trunc only chops off sign bits, and not all of them,
then we can simply sign-extend the trunc's source.

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll

index 21ca15c..68f6d0e 100644 (file)
@@ -1493,12 +1493,22 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) {
   // If the input is a trunc from the destination type, then turn sext(trunc(x))
   // into shifts.
   Value *X;
-  if (match(Src, m_OneUse(m_Trunc(m_Value(X)))) && X->getType() == DestTy) {
-    // sext(trunc(X)) --> ashr(shl(X, C), C)
+  if (match(Src, m_Trunc(m_Value(X)))) {
     unsigned SrcBitSize = SrcTy->getScalarSizeInBits();
     unsigned DestBitSize = DestTy->getScalarSizeInBits();
-    Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize);
-    return BinaryOperator::CreateAShr(Builder.CreateShl(X, ShAmt), ShAmt);
+    unsigned XBitSize = X->getType()->getScalarSizeInBits();
+
+    // Iff X had more sign bits than the number of bits that were chopped off
+    // by the truncation, we can directly sign-extend the X.
+    unsigned XNumSignBits = ComputeNumSignBits(X, 0, &CI);
+    if (XNumSignBits > (XBitSize - SrcBitSize))
+      return CastInst::Create(Instruction::SExt, X, DestTy);
+
+    if (Src->hasOneUse() && X->getType() == DestTy) {
+      // sext(trunc(X)) --> ashr(shl(X, C), C)
+      Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize);
+      return BinaryOperator::CreateAShr(Builder.CreateShl(X, ShAmt), ShAmt);
+    }
   }
 
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src))
index 10098d9..e5912ee 100644 (file)
@@ -13,8 +13,7 @@ define i16 @t0(i8 %x) {
 ; CHECK-LABEL: @t0(
 ; CHECK-NEXT:    [[A:%.*]] = ashr i8 [[X:%.*]], 5
 ; CHECK-NEXT:    call void @use8(i8 [[A]])
-; CHECK-NEXT:    [[B:%.*]] = trunc i8 [[A]] to i4
-; CHECK-NEXT:    [[C:%.*]] = sext i4 [[B]] to i16
+; CHECK-NEXT:    [[C:%.*]] = sext i8 [[A]] to i16
 ; CHECK-NEXT:    ret i16 [[C]]
 ;
   %a = ashr i8 %x, 5
@@ -28,8 +27,7 @@ define i16 @t1(i8 %x) {
 ; CHECK-LABEL: @t1(
 ; CHECK-NEXT:    [[A:%.*]] = ashr i8 [[X:%.*]], 4
 ; CHECK-NEXT:    call void @use8(i8 [[A]])
-; CHECK-NEXT:    [[B:%.*]] = trunc i8 [[A]] to i4
-; CHECK-NEXT:    [[C:%.*]] = sext i4 [[B]] to i16
+; CHECK-NEXT:    [[C:%.*]] = sext i8 [[A]] to i16
 ; CHECK-NEXT:    ret i16 [[C]]
 ;
   %a = ashr i8 %x, 4
@@ -59,8 +57,7 @@ define <2 x i16> @t3_vec(<2 x i8> %x) {
 ; CHECK-LABEL: @t3_vec(
 ; CHECK-NEXT:    [[A:%.*]] = ashr <2 x i8> [[X:%.*]], <i8 4, i8 4>
 ; CHECK-NEXT:    call void @usevec(<2 x i8> [[A]])
-; CHECK-NEXT:    [[B:%.*]] = trunc <2 x i8> [[A]] to <2 x i4>
-; CHECK-NEXT:    [[C:%.*]] = sext <2 x i4> [[B]] to <2 x i16>
+; CHECK-NEXT:    [[C:%.*]] = sext <2 x i8> [[A]] to <2 x i16>
 ; CHECK-NEXT:    ret <2 x i16> [[C]]
 ;
   %a = ashr <2 x i8> %x, <i8 4, i8 4>
@@ -91,7 +88,7 @@ define i16 @t5_extrause(i8 %x) {
 ; CHECK-NEXT:    call void @use8(i8 [[A]])
 ; CHECK-NEXT:    [[B:%.*]] = trunc i8 [[A]] to i4
 ; CHECK-NEXT:    call void @use4(i4 [[B]])
-; CHECK-NEXT:    [[C:%.*]] = sext i4 [[B]] to i16
+; CHECK-NEXT:    [[C:%.*]] = sext i8 [[A]] to i16
 ; CHECK-NEXT:    ret i16 [[C]]
 ;
   %a = ashr i8 %x, 5