From 664b7a4cd51d9273888e79688f64cc8bbcbdbe25 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Mon, 19 Jun 2023 12:27:46 +0200 Subject: [PATCH] [SCCP] Fix conversion of range to constant for vectors (PR63380) The ConstantRange specifies the range of the scalar elements in the vector. When converting into a Constant, we need to create a vector splat with the correct type. For that purpose, pass in the expected type for the constant. Fixes https://github.com/llvm/llvm-project/issues/63380. --- llvm/include/llvm/Transforms/Utils/SCCPSolver.h | 2 +- llvm/lib/Transforms/Utils/SCCPSolver.cpp | 64 ++++++++++++++----------- llvm/test/Transforms/SCCP/intrinsics.ll | 14 ++++++ 3 files changed, 52 insertions(+), 28 deletions(-) diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h index 3754b51..7930d95 100644 --- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h +++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h @@ -160,7 +160,7 @@ public: /// Helper to return a Constant if \p LV is either a constant or a constant /// range with a single element. - Constant *getConstant(const ValueLatticeElement &LV) const; + Constant *getConstant(const ValueLatticeElement &LV, Type *Ty) const; /// Return either a Constant or nullptr for a given Value. Constant *getConstantOrNull(Value *V) const; diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index 902651a..24d1a46 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -394,8 +394,8 @@ class SCCPInstVisitor : public InstVisitor { LLVMContext &Ctx; private: - ConstantInt *getConstantInt(const ValueLatticeElement &IV) const { - return dyn_cast_or_null(getConstant(IV)); + ConstantInt *getConstantInt(const ValueLatticeElement &IV, Type *Ty) const { + return dyn_cast_or_null(getConstant(IV, Ty)); } // pushToWorkList - Helper for markConstant/markOverdefined @@ -778,7 +778,7 @@ public: bool isStructLatticeConstant(Function *F, StructType *STy); - Constant *getConstant(const ValueLatticeElement &LV) const; + Constant *getConstant(const ValueLatticeElement &LV, Type *Ty) const; Constant *getConstantOrNull(Value *V) const; @@ -881,14 +881,18 @@ bool SCCPInstVisitor::isStructLatticeConstant(Function *F, StructType *STy) { return true; } -Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const { - if (LV.isConstant()) - return LV.getConstant(); +Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV, + Type *Ty) const { + if (LV.isConstant()) { + Constant *C = LV.getConstant(); + assert(C->getType() == Ty && "Type mismatch"); + return C; + } if (LV.isConstantRange()) { const auto &CR = LV.getConstantRange(); if (CR.getSingleElement()) - return ConstantInt::get(Ctx, *CR.getSingleElement()); + return ConstantInt::get(Ty, *CR.getSingleElement()); } return nullptr; } @@ -904,7 +908,7 @@ Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const { for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) { ValueLatticeElement LV = LVs[I]; ConstVals.push_back(SCCPSolver::isConstant(LV) - ? getConstant(LV) + ? getConstant(LV, ST->getElementType(I)) : UndefValue::get(ST->getElementType(I))); } Const = ConstantStruct::get(ST, ConstVals); @@ -912,7 +916,7 @@ Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const { const ValueLatticeElement &LV = getLatticeValueFor(V); if (SCCPSolver::isOverdefined(LV)) return nullptr; - Const = SCCPSolver::isConstant(LV) ? getConstant(LV) + Const = SCCPSolver::isConstant(LV) ? getConstant(LV, V->getType()) : UndefValue::get(V->getType()); } assert(Const && "Constant is nullptr here!"); @@ -1007,7 +1011,7 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, } ValueLatticeElement BCValue = getValueState(BI->getCondition()); - ConstantInt *CI = getConstantInt(BCValue); + ConstantInt *CI = getConstantInt(BCValue, BI->getCondition()->getType()); if (!CI) { // Overdefined condition variables, and branches on unfoldable constant // conditions, mean the branch could go either way. @@ -1033,7 +1037,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, return; } const ValueLatticeElement &SCValue = getValueState(SI->getCondition()); - if (ConstantInt *CI = getConstantInt(SCValue)) { + if (ConstantInt *CI = + getConstantInt(SCValue, SI->getCondition()->getType())) { Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true; return; } @@ -1064,7 +1069,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, if (auto *IBR = dyn_cast(&TI)) { // Casts are folded by visitCastInst. ValueLatticeElement IBRValue = getValueState(IBR->getAddress()); - BlockAddress *Addr = dyn_cast_or_null(getConstant(IBRValue)); + BlockAddress *Addr = dyn_cast_or_null( + getConstant(IBRValue, IBR->getAddress()->getType())); if (!Addr) { // Overdefined or unknown condition? // All destinations are executable! if (!IBRValue.isUnknownOrUndef()) @@ -1219,7 +1225,7 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { if (OpSt.isUnknownOrUndef()) return; - if (Constant *OpC = getConstant(OpSt)) { + if (Constant *OpC = getConstant(OpSt, I.getOperand(0)->getType())) { // Fold the constant as we build. Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL); markConstant(&I, C); @@ -1354,7 +1360,8 @@ void SCCPInstVisitor::visitSelectInst(SelectInst &I) { if (CondValue.isUnknownOrUndef()) return; - if (ConstantInt *CondCB = getConstantInt(CondValue)) { + if (ConstantInt *CondCB = + getConstantInt(CondValue, I.getCondition()->getType())) { Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue(); mergeInValue(&I, getValueState(OpVal)); return; @@ -1387,8 +1394,8 @@ void SCCPInstVisitor::visitUnaryOperator(Instruction &I) { return; if (SCCPSolver::isConstant(V0State)) - if (Constant *C = ConstantFoldUnaryOpOperand(I.getOpcode(), - getConstant(V0State), DL)) + if (Constant *C = ConstantFoldUnaryOpOperand( + I.getOpcode(), getConstant(V0State, I.getType()), DL)) return (void)markConstant(IV, &I, C); markOverdefined(&I); @@ -1412,8 +1419,8 @@ void SCCPInstVisitor::visitFreezeInst(FreezeInst &I) { return; if (SCCPSolver::isConstant(V0State) && - isGuaranteedNotToBeUndefOrPoison(getConstant(V0State))) - return (void)markConstant(IV, &I, getConstant(V0State)); + isGuaranteedNotToBeUndefOrPoison(getConstant(V0State, I.getType()))) + return (void)markConstant(IV, &I, getConstant(V0State, I.getType())); markOverdefined(&I); } @@ -1437,10 +1444,12 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { // If either of the operands is a constant, try to fold it to a constant. // TODO: Use information from notconstant better. if ((V1State.isConstant() || V2State.isConstant())) { - Value *V1 = SCCPSolver::isConstant(V1State) ? getConstant(V1State) - : I.getOperand(0); - Value *V2 = SCCPSolver::isConstant(V2State) ? getConstant(V2State) - : I.getOperand(1); + Value *V1 = SCCPSolver::isConstant(V1State) + ? getConstant(V1State, I.getOperand(0)->getType()) + : I.getOperand(0); + Value *V2 = SCCPSolver::isConstant(V2State) + ? getConstant(V2State, I.getOperand(1)->getType()) + : I.getOperand(1); Value *R = simplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL)); auto *C = dyn_cast_or_null(R); if (C) { @@ -1518,7 +1527,7 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { if (SCCPSolver::isOverdefined(State)) return (void)markOverdefined(&I); - if (Constant *C = getConstant(State)) { + if (Constant *C = getConstant(State, I.getOperand(i)->getType())) { Operands.push_back(C); continue; } @@ -1584,7 +1593,7 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) { ValueLatticeElement &IV = ValueState[&I]; if (SCCPSolver::isConstant(PtrVal)) { - Constant *Ptr = getConstant(PtrVal); + Constant *Ptr = getConstant(PtrVal, I.getOperand(0)->getType()); // load null is undefined. if (isa(Ptr)) { @@ -1647,7 +1656,7 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) { if (SCCPSolver::isOverdefined(State)) return (void)markOverdefined(&CB); assert(SCCPSolver::isConstant(State) && "Unknown state!"); - Operands.push_back(getConstant(State)); + Operands.push_back(getConstant(State, A->getType())); } if (SCCPSolver::isOverdefined(getValueState(&CB))) @@ -2067,8 +2076,9 @@ bool SCCPSolver::isStructLatticeConstant(Function *F, StructType *STy) { return Visitor->isStructLatticeConstant(F, STy); } -Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV) const { - return Visitor->getConstant(LV); +Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV, + Type *Ty) const { + return Visitor->getConstant(LV, Ty); } Constant *SCCPSolver::getConstantOrNull(Value *V) const { diff --git a/llvm/test/Transforms/SCCP/intrinsics.ll b/llvm/test/Transforms/SCCP/intrinsics.ll index 3fc7637..5edb317 100644 --- a/llvm/test/Transforms/SCCP/intrinsics.ll +++ b/llvm/test/Transforms/SCCP/intrinsics.ll @@ -122,3 +122,17 @@ exit: %p_umax = call i8 @llvm.umax.i8(i8 %p, i8 1) ret i8 %p_umax } + +define <4 x i32> @pr63380(<4 x i32> %input) { +; CHECK-LABEL: @pr63380( +; CHECK-NEXT: [[CTLZ_1:%.*]] = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> [[INPUT:%.*]], i1 false) +; CHECK-NEXT: [[CTLZ_2:%.*]] = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> [[CTLZ_1]], i1 true) +; CHECK-NEXT: ret <4 x i32> +; + %ctlz.1 = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> %input, i1 false) + %ctlz.2 = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> %ctlz.1, i1 true) + %ctlz.3 = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> %ctlz.2, i1 true) + ret <4 x i32> %ctlz.3 +} + +declare <4 x i32> @llvm.ctlz.v4i32(<4 x i32>, i1 immarg) -- 2.7.4