ValueTracking: Add start of computeKnownFPClass API
authorMatt Arsenault <Matthew.Arsenault@amd.com>
Wed, 1 Feb 2023 14:44:06 +0000 (10:44 -0400)
committerMatt Arsenault <arsenm2@gmail.com>
Fri, 17 Mar 2023 03:14:40 +0000 (23:14 -0400)
Add a new compute-known-bits like function to compute all
the interesting floating point properties at once.

Eventually this should absorb all the various floating point
queries we already have.

llvm/include/llvm/Analysis/ValueTracking.h
llvm/lib/Analysis/ValueTracking.cpp
llvm/unittests/Analysis/ValueTrackingTest.cpp

index 041b97c..c056bb9 100644 (file)
@@ -217,6 +217,93 @@ unsigned ComputeMaxSignificantBits(const Value *Op, const DataLayout &DL,
 Intrinsic::ID getIntrinsicForCallSite(const CallBase &CB,
                                       const TargetLibraryInfo *TLI);
 
+struct KnownFPClass {
+  /// Floating-point classes the value could be one of.
+  FPClassTest KnownFPClasses = fcAllFlags;
+
+  /// std::nullopt if the sign bit is unknown, true if the sign bit is
+  /// definitely set or false if the sign bit is definitely unset.
+  std::optional<bool> SignBit;
+
+  KnownFPClass &operator|=(const KnownFPClass &RHS) {
+    KnownFPClasses = KnownFPClasses | RHS.KnownFPClasses;
+
+    if (SignBit != RHS.SignBit)
+      SignBit = std::nullopt;
+    return *this;
+  }
+
+  void knownNot(FPClassTest RuleOut) {
+    KnownFPClasses = KnownFPClasses & ~RuleOut;
+  }
+
+  void fneg() {
+    KnownFPClasses = llvm::fneg(KnownFPClasses);
+    if (SignBit)
+      SignBit = !*SignBit;
+  }
+
+  void fabs() {
+    KnownFPClasses = llvm::fabs(KnownFPClasses);
+    SignBit = false;
+  }
+
+  /// Assume the sign bit is zero.
+  void signBitIsZero() {
+    KnownFPClasses = (KnownFPClasses & fcPositive) |
+                     (KnownFPClasses & fcNan);
+    SignBit = false;
+  }
+
+  void copysign(const KnownFPClass &Sign) {
+    // Start assuming nothing about the sign.
+    SignBit = Sign.SignBit;
+    if (!SignBit)
+      return;
+
+    if (*SignBit)
+      KnownFPClasses = KnownFPClasses & fcNegative;
+    else
+      KnownFPClasses = KnownFPClasses & fcPositive;
+  }
+
+  void resetAll() { *this = KnownFPClass(); }
+};
+
+inline KnownFPClass operator|(KnownFPClass LHS, const KnownFPClass &RHS) {
+  LHS |= RHS;
+  return LHS;
+}
+
+inline KnownFPClass operator|(const KnownFPClass &LHS, KnownFPClass &&RHS) {
+  RHS |= LHS;
+  return std::move(RHS);
+}
+
+/// Determine which floating-point classes are valid for \p V, and return them
+/// in KnownFPClass bit sets.
+///
+/// This function is defined on values with floating-point type, values vectors
+/// of floating-point type, and arrays of floating-point type.
+
+/// \p InterestedClasses is a compile time optimization hint for which floating
+/// point classes should be queried. Queries not specified in \p
+/// InterestedClasses should be reliable if they are determined during the
+/// query.
+KnownFPClass computeKnownFPClass(
+    const Value *V, const APInt &DemandedElts, const DataLayout &DL,
+    FPClassTest InterestedClasses = fcAllFlags, unsigned Depth = 0,
+    const TargetLibraryInfo *TLI = nullptr, AssumptionCache *AC = nullptr,
+    const Instruction *CxtI = nullptr, const DominatorTree *DT = nullptr,
+    OptimizationRemarkEmitter *ORE = nullptr, bool UseInstrInfo = true);
+
+KnownFPClass computeKnownFPClass(
+    const Value *V, const DataLayout &DL,
+    FPClassTest InterestedClasses = fcAllFlags, unsigned Depth = 0,
+    const TargetLibraryInfo *TLI = nullptr, AssumptionCache *AC = nullptr,
+    const Instruction *CxtI = nullptr, const DominatorTree *DT = nullptr,
+    OptimizationRemarkEmitter *ORE = nullptr, bool UseInstrInfo = true);
+
 /// Return true if we can prove that the specified FP value is never equal to
 /// -0.0.
 bool CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI,
index 1e9ce10..970218a 100644 (file)
@@ -16,6 +16,7 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -4124,6 +4125,153 @@ bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI,
   return false;
 }
 
+// TODO: Merge implementations of isKnownNeverNaN, isKnownNeverInfinity,
+// CannotBeNegativeZero, cannotBeOrderedLessThanZero into here.
+void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
+                         FPClassTest InterestedClasses, KnownFPClass &Known,
+                         unsigned Depth, const Query &Q,
+                         const TargetLibraryInfo *TLI) {
+  if (!DemandedElts) {
+    // No demanded elts, better to assume we don't know anything.
+    Known.resetAll();
+    return;
+  }
+
+  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
+
+  const APFloat *C;
+  if (match(V, m_APFloat(C))) {
+    // We know all of the classes for a scalar constant or a splat vector
+    // constant!
+    Known.KnownFPClasses = C->classify();
+    Known.SignBit = C->isNegative();
+    return;
+  }
+
+  const Operator *Op = dyn_cast<Operator>(V);
+  if (!Op)
+    return;
+
+  FPClassTest KnownNotFromFlags = fcNone;
+  if (const FPMathOperator *FPOp = dyn_cast<FPMathOperator>(Op)) {
+    if (FPOp->hasNoNaNs())
+      KnownNotFromFlags |= fcNan;
+    if (FPOp->hasNoInfs())
+      KnownNotFromFlags |= fcInf;
+
+    // We no longer need to find out about these bits from inputs if we can
+    // assume this from flags.
+    InterestedClasses &= ~KnownNotFromFlags;
+  }
+
+  auto ClearClassesFromFlags = make_scope_exit([=, &Known] {
+    Known.knownNot(KnownNotFromFlags);
+  });
+
+  // All recursive calls that increase depth must come after this.
+  if (Depth == MaxAnalysisRecursionDepth)
+    return;
+
+  switch (Op->getOpcode()) {
+  case Instruction::FNeg: {
+    computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedClasses,
+                        Known, Depth + 1, Q, TLI);
+    Known.fneg();
+    break;
+  }
+  case Instruction::Select: {
+    KnownFPClass Known2;
+    computeKnownFPClass(Op->getOperand(1), DemandedElts, InterestedClasses,
+                        Known, Depth + 1, Q, TLI);
+    computeKnownFPClass(Op->getOperand(2), DemandedElts, InterestedClasses,
+                        Known2, Depth + 1, Q, TLI);
+    Known |= Known2;
+    break;
+  }
+  case Instruction::Call: {
+    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op)) {
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::fabs:
+        computeKnownFPClass(II->getArgOperand(0), DemandedElts,
+                            InterestedClasses, Known, Depth + 1, Q, TLI);
+        Known.fabs();
+        break;
+      case Intrinsic::copysign: {
+        KnownFPClass KnownSign;
+
+        computeKnownFPClass(II->getArgOperand(0), DemandedElts,
+                            InterestedClasses, Known, Depth + 1, Q, TLI);
+        computeKnownFPClass(II->getArgOperand(1), DemandedElts,
+                            InterestedClasses, KnownSign, Depth + 1, Q, TLI);
+        Known.copysign(KnownSign);
+        break;
+      }
+      default:
+        break;
+      }
+    }
+
+    break;
+  }
+  case Instruction::SIToFP:
+  case Instruction::UIToFP: {
+    // Cannot produce nan
+    Known.knownNot(fcNan);
+    if (Op->getOpcode() == Instruction::UIToFP)
+      Known.signBitIsZero();
+
+    if (InterestedClasses & fcInf) {
+      // Get width of largest magnitude integer (remove a bit if signed).
+      // This still works for a signed minimum value because the largest FP
+      // value is scaled by some fraction close to 2.0 (1.0 + 0.xxxx).
+      int IntSize = Op->getOperand(0)->getType()->getScalarSizeInBits();
+      if (Op->getOpcode() == Instruction::SIToFP)
+        --IntSize;
+
+      // If the exponent of the largest finite FP value can hold the largest
+      // integer, the result of the cast must be finite.
+      Type *FPTy = Op->getType()->getScalarType();
+      if (ilogb(APFloat::getLargest(FPTy->getFltSemantics())) >= IntSize)
+        Known.knownNot(fcInf);
+    }
+
+    break;
+  }
+  default:
+    break;
+  }
+
+  // TODO: Handle assumes
+}
+
+KnownFPClass llvm::computeKnownFPClass(
+    const Value *V, const APInt &DemandedElts, const DataLayout &DL,
+    FPClassTest InterestedClasses, unsigned Depth, const TargetLibraryInfo *TLI,
+    AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
+    OptimizationRemarkEmitter *ORE, bool UseInstrInfo) {
+  KnownFPClass KnownClasses;
+  ::computeKnownFPClass(V, DemandedElts, InterestedClasses, KnownClasses, Depth,
+                        Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE),
+                        TLI);
+  return KnownClasses;
+}
+
+KnownFPClass
+llvm::computeKnownFPClass(const Value *V, const DataLayout &DL,
+                          FPClassTest InterestedClasses, unsigned Depth,
+                          const TargetLibraryInfo *TLI, AssumptionCache *AC,
+                          const Instruction *CxtI, const DominatorTree *DT,
+                          OptimizationRemarkEmitter *ORE, bool UseInstrInfo) {
+  KnownFPClass Known;
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+  ::computeKnownFPClass(V, DemandedElts, InterestedClasses, Known, Depth,
+                        Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE),
+                        TLI);
+  return Known;
+}
+
 Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) {
 
   // All byte-wide stores are splatable, even of arbitrary variables.
index 713cfed..0211627 100644 (file)
@@ -110,6 +110,18 @@ protected:
   }
 };
 
+class ComputeKnownFPClassTest : public ValueTrackingTest {
+protected:
+  void expectKnownFPClass(unsigned KnownTrue, std::optional<bool> SignBitKnown,
+                          Instruction *TestVal = nullptr) {
+    if (!TestVal)
+      TestVal = A;
+
+    KnownFPClass Known = computeKnownFPClass(TestVal, M->getDataLayout());
+    EXPECT_EQ(KnownTrue, Known.KnownFPClasses);
+    EXPECT_EQ(SignBitKnown, Known.SignBit);
+  }
+};
 }
 
 TEST_F(MatchSelectPatternTest, SimpleFMin) {
@@ -1258,6 +1270,199 @@ TEST_F(ComputeKnownBitsTest, ComputeKnownMulBits) {
   expectKnownBits(/*zero*/ 95u, /*one*/ 32u);
 }
 
+TEST_F(ComputeKnownFPClassTest, SelectPos0) {
+  parseAssembly(
+      "define float @test(i1 %cond) {\n"
+      "  %A = select i1 %cond, float 0.0, float 0.0"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcPosZero, false);
+}
+
+TEST_F(ComputeKnownFPClassTest, SelectNeg0) {
+  parseAssembly(
+      "define float @test(i1 %cond) {\n"
+      "  %A = select i1 %cond, float -0.0, float -0.0"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcNegZero, true);
+}
+
+TEST_F(ComputeKnownFPClassTest, SelectPosOrNeg0) {
+  parseAssembly(
+      "define float @test(i1 %cond) {\n"
+      "  %A = select i1 %cond, float 0.0, float -0.0"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcZero, std::nullopt);
+}
+
+TEST_F(ComputeKnownFPClassTest, SelectPosInf) {
+  parseAssembly(
+      "define float @test(i1 %cond) {\n"
+      "  %A = select i1 %cond, float 0x7FF0000000000000, float 0x7FF0000000000000"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcPosInf, false);
+}
+
+TEST_F(ComputeKnownFPClassTest, SelectNegInf) {
+  parseAssembly(
+      "define float @test(i1 %cond) {\n"
+      "  %A = select i1 %cond, float 0xFFF0000000000000, float 0xFFF0000000000000"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcNegInf, true);
+}
+
+TEST_F(ComputeKnownFPClassTest, SelectPosOrNegInf) {
+  parseAssembly(
+      "define float @test(i1 %cond) {\n"
+      "  %A = select i1 %cond, float 0x7FF0000000000000, float 0xFFF0000000000000"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcInf, std::nullopt);
+}
+
+TEST_F(ComputeKnownFPClassTest, SelectNNaN) {
+  parseAssembly(
+      "define float @test(i1 %cond, float %arg0, float %arg1) {\n"
+      "  %A = select nnan i1 %cond, float %arg0, float %arg1"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(~fcNan, std::nullopt);
+}
+
+TEST_F(ComputeKnownFPClassTest, SelectNInf) {
+  parseAssembly(
+      "define float @test(i1 %cond, float %arg0, float %arg1) {\n"
+      "  %A = select ninf i1 %cond, float %arg0, float %arg1"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(~fcInf, std::nullopt);
+}
+
+TEST_F(ComputeKnownFPClassTest, SelectNNaNNInf) {
+  parseAssembly(
+      "define float @test(i1 %cond, float %arg0, float %arg1) {\n"
+      "  %A = select nnan ninf i1 %cond, float %arg0, float %arg1"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(~(fcNan | fcInf), std::nullopt);
+}
+
+TEST_F(ComputeKnownFPClassTest, FNegNInf) {
+  parseAssembly(
+      "define float @test(float %arg) {\n"
+      "  %A = fneg ninf float %arg"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(~fcInf, std::nullopt);
+}
+
+TEST_F(ComputeKnownFPClassTest, FabsUnknown) {
+  parseAssembly(
+      "declare float @llvm.fabs.f32(float)"
+      "define float @test(float %arg) {\n"
+      "  %A = call float @llvm.fabs.f32(float %arg)"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcAllFlags, false);
+}
+
+TEST_F(ComputeKnownFPClassTest, FNegFabsUnknown) {
+  parseAssembly(
+      "declare float @llvm.fabs.f32(float)"
+      "define float @test(float %arg) {\n"
+      "  %fabs = call float @llvm.fabs.f32(float %arg)"
+      "  %A = fneg float %fabs"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcAllFlags, true);
+}
+
+TEST_F(ComputeKnownFPClassTest, NegFabsNInf) {
+  parseAssembly(
+      "declare float @llvm.fabs.f32(float)"
+      "define float @test(float %arg) {\n"
+      "  %fabs = call ninf float @llvm.fabs.f32(float %arg)"
+      "  %A = fneg float %fabs"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(~fcInf, true);
+}
+
+TEST_F(ComputeKnownFPClassTest, FNegFabsNNaN) {
+  parseAssembly(
+      "declare float @llvm.fabs.f32(float)"
+      "define float @test(float %arg) {\n"
+      "  %fabs = call nnan float @llvm.fabs.f32(float %arg)"
+      "  %A = fneg float %fabs"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(~fcNan, true);
+}
+
+TEST_F(ComputeKnownFPClassTest, CopySignNNanSrc0) {
+  parseAssembly(
+      "declare float @llvm.fabs.f32(float)\n"
+      "declare float @llvm.copysign.f32(float, float)\n"
+      "define float @test(float %arg0, float %arg1) {\n"
+      "  %fabs = call nnan float @llvm.fabs.f32(float %arg0)"
+      "  %A = call float @llvm.copysign.f32(float %fabs, float %arg1)"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(~fcNan, std::nullopt);
+}
+
+TEST_F(ComputeKnownFPClassTest, CopySignNInfSrc0_NegSign) {
+  parseAssembly(
+      "declare float @llvm.sqrt.f32(float)\n"
+      "declare float @llvm.copysign.f32(float, float)\n"
+      "define float @test(float %arg0, float %arg1) {\n"
+      "  %ninf = call ninf float @llvm.sqrt.f32(float %arg0)"
+      "  %A = call float @llvm.copysign.f32(float %ninf, float -1.0)"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcNegFinite, true);
+}
+
+TEST_F(ComputeKnownFPClassTest, CopySignNInfSrc0_PosSign) {
+  parseAssembly(
+      "declare float @llvm.sqrt.f32(float)\n"
+      "declare float @llvm.copysign.f32(float, float)\n"
+      "define float @test(float %arg0, float %arg1) {\n"
+      "  %ninf = call ninf float @llvm.sqrt.f32(float %arg0)"
+      "  %A = call float @llvm.copysign.f32(float %ninf, float 1.0)"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcPosFinite, false);
+}
+
+TEST_F(ComputeKnownFPClassTest, UIToFP) {
+  parseAssembly(
+      "define float @test(i32 %arg0, i16 %arg1) {\n"
+      "  %A = uitofp i32 %arg0 to float"
+      "  %A2 = uitofp i16 %arg1 to half"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcPosFinite, false, A);
+  expectKnownFPClass(fcPositive, false, A2);
+}
+
+TEST_F(ComputeKnownFPClassTest, SIToFP) {
+  parseAssembly(
+      "define float @test(i32 %arg0, i16 %arg1, i17 %arg2) {\n"
+      "  %A = sitofp i32 %arg0 to float"
+      "  %A2 = sitofp i16 %arg1 to half"
+      "  %A3 = sitofp i17 %arg2 to half"
+      "  ret float %A\n"
+      "}\n");
+  expectKnownFPClass(fcFinite, std::nullopt, A);
+  expectKnownFPClass(fcFinite, std::nullopt, A2);
+  expectKnownFPClass(~fcNan, std::nullopt, A3);
+}
+
 TEST_F(ValueTrackingTest, isNonZeroRecurrence) {
   parseAssembly(R"(
     define i1 @test(i8 %n, i8 %r) {