From e3c9070f06d33907e77956fd16abca90bf5ef819 Mon Sep 17 00:00:00 2001 From: Zhaoshi Zheng Date: Thu, 23 Mar 2017 18:06:09 +0000 Subject: [PATCH] Model ashr(shl(x, n), m) as mul(x, 2^(n-m)) when n > m Given below case: %y = shl %x, n %z = ashr %y, m when n = m, SCEV models it as sext(trunc(x)). This patch tries to handle the case where n > m by using sext(mul(trunc(x), 2^(n-m))) as the SCEV expression. llvm-svn: 298631 --- llvm/lib/Analysis/ScalarEvolution.cpp | 65 ++++++++++++------ llvm/test/Analysis/ScalarEvolution/sext-mul.ll | 89 +++++++++++++++++++++++++ llvm/test/Analysis/ScalarEvolution/sext-zero.ll | 39 +++++++++++ 3 files changed, 174 insertions(+), 19 deletions(-) create mode 100644 llvm/test/Analysis/ScalarEvolution/sext-mul.ll create mode 100644 llvm/test/Analysis/ScalarEvolution/sext-zero.ll diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index c820464..5863406 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -5356,28 +5356,55 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case Instruction::AShr: - // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) - if (Operator *L = dyn_cast<Operator>(BO->LHS)) - if (L->getOpcode() == Instruction::Shl && - L->getOperand(1) == BO->RHS) { - uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType()); - - // If the shift count is not less than the bitwidth, the result of - // the shift is undefined. Don't try to analyze it, because the - // resolution chosen here may differ from the resolution chosen in - // other parts of the compiler. - if (CI->getValue().uge(BitWidth)) - break; + // AShr X, C, where C is a constant.
+ ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS); + if (!CI) + break; + + Type *OuterTy = BO->LHS->getType(); + uint64_t BitWidth = getTypeSizeInBits(OuterTy); + // If the shift count is not less than the bitwidth, the result of + // the shift is undefined. Don't try to analyze it, because the + // resolution chosen here may differ from the resolution chosen in + // other parts of the compiler. + if (CI->getValue().uge(BitWidth)) + break; - uint64_t Amt = BitWidth - CI->getZExtValue(); - if (Amt == BitWidth) - return getSCEV(L->getOperand(0)); // shift by zero --> noop + if (CI->isNullValue()) + return getSCEV(BO->LHS); // shift by zero --> noop + + uint64_t AShrAmt = CI->getZExtValue(); + Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt); + + Operator *L = dyn_cast<Operator>(BO->LHS); + if (L && L->getOpcode() == Instruction::Shl) { + // X = Shl A, n + // Y = AShr X, m + // Both n and m are constant. + + const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0)); + if (L->getOperand(1) == BO->RHS) + // For a two-shift sext-inreg, i.e. n = m, + // use sext(trunc(x)) as the SCEV expression. + return getSignExtendExpr( + getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy); + + ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1)); + if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) { + uint64_t ShlAmt = ShlAmtCI->getZExtValue(); + if (ShlAmt > AShrAmt) { + // When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV + // expression. We already checked that ShlAmt < BitWidth, so + // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as + // ShlAmt - AShrAmt < BitWidth - AShrAmt.
+ APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, + ShlAmt - AShrAmt); return getSignExtendExpr( - getTruncateExpr(getSCEV(L->getOperand(0)), - IntegerType::get(getContext(), Amt)), - BO->LHS->getType()); + getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy), + getConstant(Mul)), OuterTy); } + } + } break; } } diff --git a/llvm/test/Analysis/ScalarEvolution/sext-mul.ll b/llvm/test/Analysis/ScalarEvolution/sext-mul.ll new file mode 100644 index 0000000..ca25d9e --- /dev/null +++ b/llvm/test/Analysis/ScalarEvolution/sext-mul.ll @@ -0,0 +1,89 @@ +; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s + +; CHECK: %tmp9 = shl i64 %tmp8, 33 +; CHECK-NEXT: --> {{.*}} Exits: (-8589934592 + (8589934592 * (zext i32 %arg2 to i64))) +; CHECK: %tmp10 = ashr exact i64 %tmp9, 32 +; CHECK-NEXT: --> {{.*}} Exits: (sext i32 (-2 + (2 * %arg2)) to i64) +; CHECK: %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10 +; CHECK-NEXT: --> {{.*}} Exits: ((4 * (sext i32 (-2 + (2 * %arg2)) to i64)) + %arg) +; CHECK: %tmp14 = or i64 %tmp10, 1 +; CHECK-NEXT: --> {{.*}} Exits: (1 + (sext i32 (-2 + (2 * %arg2)) to i64)) +; CHECK: %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14 +; CHECK-NEXT: --> {{.*}} Exits: (4 + (4 * (sext i32 (-2 + (2 * %arg2)) to i64)) + %arg) +; CHECK:Loop %bb7: backedge-taken count is (-1 + (zext i32 %arg2 to i64)) +; CHECK-NEXT:Loop %bb7: max backedge-taken count is -1 +; CHECK-NEXT:Loop %bb7: Predicated backedge-taken count is (-1 + (zext i32 %arg2 to i64)) + +define void @foo(i32* nocapture %arg, i32 %arg1, i32 %arg2) { +bb: + %tmp = icmp sgt i32 %arg2, 0 + br i1 %tmp, label %bb3, label %bb6 + +bb3: ; preds = %bb + %tmp4 = zext i32 %arg2 to i64 + br label %bb7 + +bb5: ; preds = %bb7 + br label %bb6 + +bb6: ; preds = %bb5, %bb + ret void + +bb7: ; preds = %bb7, %bb3 + %tmp8 = phi i64 [ %tmp18, %bb7 ], [ 0, %bb3 ] + %tmp9 = shl i64 %tmp8, 33 + %tmp10 = ashr exact i64 %tmp9, 32 + %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10 + %tmp12 = load 
i32, i32* %tmp11, align 4 + %tmp13 = sub nsw i32 %tmp12, %arg1 + store i32 %tmp13, i32* %tmp11, align 4 + %tmp14 = or i64 %tmp10, 1 + %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14 + %tmp16 = load i32, i32* %tmp15, align 4 + %tmp17 = mul nsw i32 %tmp16, %arg1 + store i32 %tmp17, i32* %tmp15, align 4 + %tmp18 = add nuw nsw i64 %tmp8, 1 + %tmp19 = icmp eq i64 %tmp18, %tmp4 + br i1 %tmp19, label %bb5, label %bb7 +} + +; CHECK: %t10 = ashr exact i128 %t9, 1 +; CHECK-NEXT: --> {{.*}} Exits: (sext i127 (-633825300114114700748351602688 + (633825300114114700748351602688 * (zext i32 %arg5 to i127))) to i128) +; CHECK: %t14 = or i128 %t10, 1 +; CHECK-NEXT: --> {{.*}} Exits: (1 + (sext i127 (-633825300114114700748351602688 + (633825300114114700748351602688 * (zext i32 %arg5 to i127))) to i128)) +; CHECK: Loop %bb7: backedge-taken count is (-1 + (zext i32 %arg5 to i128)) +; CHECK-NEXT: Loop %bb7: max backedge-taken count is -1 +; CHECK-NEXT: Loop %bb7: Predicated backedge-taken count is (-1 + (zext i32 %arg5 to i128)) + +define void @goo(i32* nocapture %arg3, i32 %arg4, i32 %arg5) { +bb: + %t = icmp sgt i32 %arg5, 0 + br i1 %t, label %bb3, label %bb6 + +bb3: ; preds = %bb + %t4 = zext i32 %arg5 to i128 + br label %bb7 + +bb5: ; preds = %bb7 + br label %bb6 + +bb6: ; preds = %bb5, %bb + ret void + +bb7: ; preds = %bb7, %bb3 + %t8 = phi i128 [ %t18, %bb7 ], [ 0, %bb3 ] + %t9 = shl i128 %t8, 100 + %t10 = ashr exact i128 %t9, 1 + %t11 = getelementptr inbounds i32, i32* %arg3, i128 %t10 + %t12 = load i32, i32* %t11, align 4 + %t13 = sub nsw i32 %t12, %arg4 + store i32 %t13, i32* %t11, align 4 + %t14 = or i128 %t10, 1 + %t15 = getelementptr inbounds i32, i32* %arg3, i128 %t14 + %t16 = load i32, i32* %t15, align 4 + %t17 = mul nsw i32 %t16, %arg4 + store i32 %t17, i32* %t15, align 4 + %t18 = add nuw nsw i128 %t8, 1 + %t19 = icmp eq i128 %t18, %t4 + br i1 %t19, label %bb5, label %bb7 +} diff --git a/llvm/test/Analysis/ScalarEvolution/sext-zero.ll 
b/llvm/test/Analysis/ScalarEvolution/sext-zero.ll new file mode 100644 index 0000000..cac4263 --- /dev/null +++ b/llvm/test/Analysis/ScalarEvolution/sext-zero.ll @@ -0,0 +1,39 @@ +; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s + +; CHECK: %tmp9 = shl i64 %tmp8, 33 +; CHECK-NEXT: --> {{.*}} Exits: (-8589934592 + (8589934592 * (zext i32 %arg2 to i64))) +; CHECK-NEXT: %tmp10 = ashr exact i64 %tmp9, 0 +; CHECK-NEXT: --> {{.*}} Exits: (-8589934592 + (8589934592 * (zext i32 %arg2 to i64))) + +define void @foo(i32* nocapture %arg, i32 %arg1, i32 %arg2) { +bb: + %tmp = icmp sgt i32 %arg2, 0 + br i1 %tmp, label %bb3, label %bb6 + +bb3: ; preds = %bb + %tmp4 = zext i32 %arg2 to i64 + br label %bb7 + +bb5: ; preds = %bb7 + br label %bb6 + +bb6: ; preds = %bb5, %bb + ret void + +bb7: ; preds = %bb7, %bb3 + %tmp8 = phi i64 [ %tmp18, %bb7 ], [ 0, %bb3 ] + %tmp9 = shl i64 %tmp8, 33 + %tmp10 = ashr exact i64 %tmp9, 0 + %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10 + %tmp12 = load i32, i32* %tmp11, align 4 + %tmp13 = sub nsw i32 %tmp12, %arg1 + store i32 %tmp13, i32* %tmp11, align 4 + %tmp14 = or i64 %tmp10, 1 + %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14 + %tmp16 = load i32, i32* %tmp15, align 4 + %tmp17 = mul nsw i32 %tmp16, %arg1 + store i32 %tmp17, i32* %tmp15, align 4 + %tmp18 = add nuw nsw i64 %tmp8, 1 + %tmp19 = icmp eq i64 %tmp18, %tmp4 + br i1 %tmp19, label %bb5, label %bb7 +} -- 2.7.4