From d4ebfd84008b309bda5a48f758d9f08cbb053552 Mon Sep 17 00:00:00 2001 From: Owen Anderson Date: Wed, 6 Feb 2013 22:43:31 +0000 Subject: [PATCH] Signficantly generalize our ability to constant fold floating point intrinsics, including ones on half types. llvm-svn: 174555 --- llvm/lib/Analysis/ConstantFolding.cpp | 102 ++++++++++++++++++++++++++++----- llvm/test/Transforms/ConstProp/half.ll | 42 ++++++++++++++ 2 files changed, 130 insertions(+), 14 deletions(-) create mode 100644 llvm/test/Transforms/ConstProp/half.ll diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 91424b2..e499c73 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -289,6 +289,10 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, C = FoldBitCast(C, Type::getInt32Ty(C->getContext()), TD); return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, TD); } + if (CFP->getType()->isHalfTy()){ + C = FoldBitCast(C, Type::getInt16Ty(C->getContext()), TD); + return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, TD); + } return false; } @@ -381,7 +385,9 @@ static Constant *FoldReinterpretLoadFromConstPtr(Constant *C, // that address spaces don't matter here since we're not going to result in // an actual new load. Type *MapTy; - if (LoadTy->isFloatTy()) + if (LoadTy->isHalfTy()) + MapTy = Type::getInt16PtrTy(C->getContext()); + else if (LoadTy->isFloatTy()) MapTy = Type::getInt32PtrTy(C->getContext()); else if (LoadTy->isDoubleTy()) MapTy = Type::getInt64PtrTy(C->getContext()); @@ -1089,6 +1095,13 @@ Constant *llvm::ConstantFoldLoadThroughGEPIndices(Constant *C, bool llvm::canConstantFoldCallTo(const Function *F) { switch (F->getIntrinsicID()) { + case Intrinsic::fabs: + case Intrinsic::log: + case Intrinsic::log2: + case Intrinsic::log10: + case Intrinsic::exp: + case Intrinsic::exp2: + case Intrinsic::floor: case Intrinsic::sqrt: case Intrinsic::pow: case Intrinsic::powi: @@ -1156,11 +1169,17 @@ static Constant *ConstantFoldFP(double (*NativeFP)(double), double V, return 0; } + if (Ty->isHalfTy()) { + APFloat APF(V); + bool unused; + APF.convert(APFloat::IEEEhalf, APFloat::rmNearestTiesToEven, &unused); + return ConstantFP::get(Ty->getContext(), APF); + } if (Ty->isFloatTy()) return ConstantFP::get(Ty->getContext(), APFloat((float)V)); if (Ty->isDoubleTy()) return ConstantFP::get(Ty->getContext(), APFloat(V)); - llvm_unreachable("Can only constant fold float/double"); + llvm_unreachable("Can only constant fold half/float/double"); } static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), @@ -1172,11 +1191,17 @@ static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), return 0; } + if (Ty->isHalfTy()) { + APFloat APF(V); + bool unused; + APF.convert(APFloat::IEEEhalf, APFloat::rmNearestTiesToEven, &unused); + return ConstantFP::get(Ty->getContext(), APF); + } if (Ty->isFloatTy()) return ConstantFP::get(Ty->getContext(), APFloat((float)V)); if (Ty->isDoubleTy()) return ConstantFP::get(Ty->getContext(), APFloat(V)); - llvm_unreachable("Can only constant fold float/double"); + llvm_unreachable("Can only constant fold half/float/double"); } /// ConstantFoldConvertToInt - Attempt to an SSE floating point to integer @@ -1228,7 +1253,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, if (!TLI) return 0; - if (!Ty->isFloatTy() && !Ty->isDoubleTy()) + if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy()) return 0; /// We only fold functions with finite arguments. Folding NaN and inf is @@ -1241,8 +1266,36 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, /// the host native double versions. Float versions are not called /// directly but for all these it is true (float)(f((double)arg)) == /// f(arg). Long double not supported yet. - double V = Ty->isFloatTy() ? (double)Op->getValueAPF().convertToFloat() : - Op->getValueAPF().convertToDouble(); + double V; + if (Ty->isFloatTy()) + V = Op->getValueAPF().convertToFloat(); + else if (Ty->isDoubleTy()) + V = Op->getValueAPF().convertToDouble(); + else { + bool unused; + APFloat APF = Op->getValueAPF(); + APF.convert(APFloat::IEEEdouble, APFloat::rmNearestTiesToEven, &unused); + V = APF.convertToDouble(); + } + + switch (F->getIntrinsicID()) { + default: break; + case Intrinsic::fabs: + return ConstantFoldFP(fabs, V, Ty); + case Intrinsic::log2: + return ConstantFoldFP(log2, V, Ty); + case Intrinsic::log: + return ConstantFoldFP(log, V, Ty); + case Intrinsic::log10: + return ConstantFoldFP(log10, V, Ty); + case Intrinsic::exp: + return ConstantFoldFP(exp, V, Ty); + case Intrinsic::exp2: + return ConstantFoldFP(exp2, V, Ty); + case Intrinsic::floor: + return ConstantFoldFP(floor, V, Ty); + } + switch (Name[0]) { case 'a': if (Name == "acos" && TLI->has(LibFunc::acos)) @@ -1284,7 +1337,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, else if (Name == "log10" && V > 0 && TLI->has(LibFunc::log10)) return ConstantFoldFP(log10, V, Ty); else if (F->getIntrinsicID() == Intrinsic::sqrt && - (Ty->isFloatTy() || Ty->isDoubleTy())) { + (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy())) { if (V >= -0.0) return ConstantFoldFP(sqrt, V, Ty); else // Undefined @@ -1376,18 +1429,35 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, if (Operands.size() == 2) { if (ConstantFP *Op1 = dyn_cast(Operands[0])) { - if (!Ty->isFloatTy() && !Ty->isDoubleTy()) + if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy()) return 0; - double Op1V = Ty->isFloatTy() ? - (double)Op1->getValueAPF().convertToFloat() : - Op1->getValueAPF().convertToDouble(); + double Op1V; + if (Ty->isFloatTy()) + Op1V = Op1->getValueAPF().convertToFloat(); + else if (Ty->isDoubleTy()) + Op1V = Op1->getValueAPF().convertToDouble(); + else { + bool unused; + APFloat APF = Op1->getValueAPF(); + APF.convert(APFloat::IEEEdouble, APFloat::rmNearestTiesToEven, &unused); + Op1V = APF.convertToDouble(); + } + if (ConstantFP *Op2 = dyn_cast(Operands[1])) { if (Op2->getType() != Op1->getType()) return 0; - double Op2V = Ty->isFloatTy() ? - (double)Op2->getValueAPF().convertToFloat(): - Op2->getValueAPF().convertToDouble(); + double Op2V; + if (Ty->isFloatTy()) + Op2V = Op2->getValueAPF().convertToFloat(); + else if (Ty->isDoubleTy()) + Op2V = Op2->getValueAPF().convertToDouble(); + else { + bool unused; + APFloat APF = Op2->getValueAPF(); + APF.convert(APFloat::IEEEdouble, APFloat::rmNearestTiesToEven, &unused); + Op2V = APF.convertToDouble(); + } if (F->getIntrinsicID() == Intrinsic::pow) { return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty); @@ -1401,6 +1471,10 @@ llvm::ConstantFoldCall(Function *F, ArrayRef Operands, if (Name == "atan2" && TLI->has(LibFunc::atan2)) return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty); } else if (ConstantInt *Op2C = dyn_cast(Operands[1])) { + if (F->getIntrinsicID() == Intrinsic::powi && Ty->isHalfTy()) + return ConstantFP::get(F->getContext(), + APFloat((float)std::pow((float)Op1V, + (int)Op2C->getZExtValue()))); if (F->getIntrinsicID() == Intrinsic::powi && Ty->isFloatTy()) return ConstantFP::get(F->getContext(), APFloat((float)std::pow((float)Op1V, diff --git a/llvm/test/Transforms/ConstProp/half.ll b/llvm/test/Transforms/ConstProp/half.ll new file mode 100644 index 0000000..3d246d8 --- /dev/null +++ b/llvm/test/Transforms/ConstProp/half.ll @@ -0,0 +1,42 @@ +; RUN: opt -constprop -S < %s | FileCheck %s + +; CHECK: fabs_call +define half @fabs_call() { +; CHECK: ret half 0xH5140 + %x = call half @llvm.fabs.f16(half -42.0) + ret half %x +} +declare half @llvm.fabs.f16(half %x) + +; CHECK: exp_call +define half @exp_call() { +; CHECK: ret half 0xH4170 + %x = call half @llvm.exp.f16(half 1.0) + ret half %x +} +declare half @llvm.exp.f16(half %x) + +; CHECK: sqrt_call +define half @sqrt_call() { +; CHECK: ret half 0xH4000 + %x = call half @llvm.sqrt.f16(half 4.0) + ret half %x +} +declare half @llvm.sqrt.f16(half %x) + +; CHECK: floor_call +define half @floor_call() { +; CHECK: ret half 0xH4000 + %x = call half @llvm.floor.f16(half 2.5) + ret half %x +} +declare half @llvm.floor.f16(half %x) + +; CHECK: pow_call +define half @pow_call() { +; CHECK: ret half 0xH4400 + %x = call half @llvm.pow.f16(half 2.0, half 2.0) + ret half %x +} +declare half @llvm.pow.f16(half %x, half %y) + -- 2.7.4