[RISCV] Merge getLoadFP*Imm into a single function.
authorCraig Topper <craig.topper@sifive.com>
Tue, 14 Mar 2023 20:04:44 +0000 (13:04 -0700)
committerCraig Topper <craig.topper@sifive.com>
Tue, 14 Mar 2023 20:11:11 +0000 (13:11 -0700)
We currently have 3 functions and 3 lookup tables. This was the
most expediant and obvious way to fix several bugs.

This patch uses a single function and single lookup
table. It uses APFloat::convert to convert from the half or double
to single precision. If the conversion doesn't have any errors or
lose any information we use the f32 table to finish the lookup.

Reviewed By: asb

Differential Revision: https://reviews.llvm.org/D145897

llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td

index 783c5eb..e8edc96 100644 (file)
@@ -496,7 +496,7 @@ public:
       return isUImm5();
     if (Kind != KindTy::FPImmediate)
       return false;
-    int Idx = RISCVLoadFPImm::getLoadFP64Imm(
+    int Idx = RISCVLoadFPImm::getLoadFPImm(
         APFloat(APFloat::IEEEdouble(), APInt(64, getFPConst())));
     // Don't allow decimal version of the minimum value. It is a different value
     // for each supported data type.
@@ -985,7 +985,7 @@ public:
       return;
     }
 
-    int Imm = RISCVLoadFPImm::getLoadFP64Imm(
+    int Imm = RISCVLoadFPImm::getLoadFPImm(
         APFloat(APFloat::IEEEdouble(), APInt(64, getFPConst())));
     Inst.addOperand(MCOperand::createImm(Imm));
   }
index 3f9003c..98c8e88 100644 (file)
@@ -214,87 +214,37 @@ bool RISCVRVC::uncompress(MCInst &OutInst, const MCInst &MI,
   return uncompressInst(OutInst, MI, STI);
 }
 
-// Lookup table for fli.h for entries 1-31. Entry 0(-1.0) is handled separately.
-// NOTE: The exponent for entry 1 is larger than entry 2 and 3 because they
-// are denormals.
-static constexpr std::pair<uint8_t, uint8_t> LoadFP16ImmArr[] = {
-    {0b00001, 0b00}, {0b00000, 0b01}, {0b00000, 0b10}, {0b00111, 0b00},
-    {0b01000, 0b00}, {0b01011, 0b00}, {0b01100, 0b00}, {0b01101, 0b00},
-    {0b01101, 0b01}, {0b01101, 0b10}, {0b01101, 0b11}, {0b01110, 0b00},
-    {0b01110, 0b01}, {0b01110, 0b10}, {0b01110, 0b11}, {0b01111, 0b00},
-    {0b01111, 0b01}, {0b01111, 0b10}, {0b01111, 0b11}, {0b10000, 0b00},
-    {0b10000, 0b01}, {0b10000, 0b10}, {0b10001, 0b00}, {0b10010, 0b00},
-    {0b10011, 0b00}, {0b10110, 0b00}, {0b10111, 0b00}, {0b11110, 0b00},
-    {0b11111, 0b00}, {0b11111, 0b00}, {0b11111, 0b10},
-};
-
-// Lookup table for fli.s for entries 1-31.
+// Lookup table for fli.s for entries 2-31.
 static constexpr std::pair<uint8_t, uint8_t> LoadFP32ImmArr[] = {
-    {0b00000001, 0b00}, {0b01101111, 0b00}, {0b01110000, 0b00},
-    {0b01110111, 0b00}, {0b01111000, 0b00}, {0b01111011, 0b00},
-    {0b01111100, 0b00}, {0b01111101, 0b00}, {0b01111101, 0b01},
-    {0b01111101, 0b10}, {0b01111101, 0b11}, {0b01111110, 0b00},
-    {0b01111110, 0b01}, {0b01111110, 0b10}, {0b01111110, 0b11},
-    {0b01111111, 0b00}, {0b01111111, 0b01}, {0b01111111, 0b10},
-    {0b01111111, 0b11}, {0b10000000, 0b00}, {0b10000000, 0b01},
-    {0b10000000, 0b10}, {0b10000001, 0b00}, {0b10000010, 0b00},
-    {0b10000011, 0b00}, {0b10000110, 0b00}, {0b10000111, 0b00},
-    {0b10001110, 0b00}, {0b10001111, 0b00}, {0b11111111, 0b00},
-    {0b11111111, 0b10},
-};
-
-// Lookup table for fli.d for entries 1-31.
-static constexpr std::pair<uint16_t, uint8_t> LoadFP64ImmArr[] = {
-    {0b00000000001, 0b00}, {0b01111101111, 0b00}, {0b01111110000, 0b00},
-    {0b01111110111, 0b00}, {0b01111111000, 0b00}, {0b01111111011, 0b00},
-    {0b01111111100, 0b00}, {0b01111111101, 0b00}, {0b01111111101, 0b01},
-    {0b01111111101, 0b10}, {0b01111111101, 0b11}, {0b01111111110, 0b00},
-    {0b01111111110, 0b01}, {0b01111111110, 0b10}, {0b01111111110, 0b11},
-    {0b01111111111, 0b00}, {0b01111111111, 0b01}, {0b01111111111, 0b10},
-    {0b01111111111, 0b11}, {0b10000000000, 0b00}, {0b10000000000, 0b01},
-    {0b10000000000, 0b10}, {0b10000000001, 0b00}, {0b10000000010, 0b00},
-    {0b10000000011, 0b00}, {0b10000000110, 0b00}, {0b10000000111, 0b00},
-    {0b10000001110, 0b00}, {0b10000001111, 0b00}, {0b11111111111, 0b00},
-    {0b11111111111, 0b10},
+    {0b01101111, 0b00}, {0b01110000, 0b00}, {0b01110111, 0b00},
+    {0b01111000, 0b00}, {0b01111011, 0b00}, {0b01111100, 0b00},
+    {0b01111101, 0b00}, {0b01111101, 0b01}, {0b01111101, 0b10},
+    {0b01111101, 0b11}, {0b01111110, 0b00}, {0b01111110, 0b01},
+    {0b01111110, 0b10}, {0b01111110, 0b11}, {0b01111111, 0b00},
+    {0b01111111, 0b01}, {0b01111111, 0b10}, {0b01111111, 0b11},
+    {0b10000000, 0b00}, {0b10000000, 0b01}, {0b10000000, 0b10},
+    {0b10000001, 0b00}, {0b10000010, 0b00}, {0b10000011, 0b00},
+    {0b10000110, 0b00}, {0b10000111, 0b00}, {0b10001110, 0b00},
+    {0b10001111, 0b00}, {0b11111111, 0b00}, {0b11111111, 0b10},
 };
 
-int RISCVLoadFPImm::getLoadFP16Imm(const APFloat &FPImm) {
-  assert(&FPImm.getSemantics() == &APFloat::IEEEhalf());
-
-  APInt Imm = FPImm.bitcastToAPInt();
-
-  if (Imm.extractBitsAsZExtValue(8, 0) != 0)
+int RISCVLoadFPImm::getLoadFPImm(APFloat FPImm) {
+  assert((&FPImm.getSemantics() == &APFloat::IEEEsingle() ||
+          &FPImm.getSemantics() == &APFloat::IEEEdouble() ||
+          &FPImm.getSemantics() == &APFloat::IEEEhalf()) &&
+         "Unexpected semantics");
+
+  // Handle the minimum normalized value which is different for each type.
+  if (FPImm.isSmallestNormalized())
+    return 1;
+
+  // Convert to single precision to use its lookup table.
+  bool LosesInfo;
+  APFloat::opStatus Status = FPImm.convert(
+      APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &LosesInfo);
+  if (Status != APFloat::opOK || LosesInfo)
     return -1;
 
-  bool Sign = Imm.extractBitsAsZExtValue(1, 15);
-  uint8_t Mantissa = Imm.extractBitsAsZExtValue(2, 8);
-  uint8_t Exp = Imm.extractBitsAsZExtValue(5, 10);
-
-  // The array isn't sorted so we must use std::find unlike fp32 and fp64.
-  auto EMI = llvm::find(LoadFP16ImmArr, std::make_pair(Exp, Mantissa));
-  if (EMI == std::end(LoadFP16ImmArr))
-    return -1;
-
-  // Table doesn't have entry 0.
-  int Entry = std::distance(std::begin(LoadFP16ImmArr), EMI) + 1;
-
-  // The only legal negative value is -1.0(entry 0). 1.0 is entry 16.
-  if (Sign) {
-    if (Entry == 16)
-      return 0;
-    return false;
-  }
-
-  // Entry 29 and 30 are both infinity, but 30 is the real infinity.
-  if (Entry == 29)
-    ++Entry;
-
-  return Entry;
-}
-
-int RISCVLoadFPImm::getLoadFP32Imm(const APFloat &FPImm) {
-  assert(&FPImm.getSemantics() == &APFloat::IEEEsingle());
-
   APInt Imm = FPImm.bitcastToAPInt();
 
   if (Imm.extractBitsAsZExtValue(21, 0) != 0)
@@ -309,38 +259,8 @@ int RISCVLoadFPImm::getLoadFP32Imm(const APFloat &FPImm) {
       EMI->second != Mantissa)
     return -1;
 
-  // Table doesn't have entry 0.
-  int Entry = std::distance(std::begin(LoadFP32ImmArr), EMI) + 1;
-
-  // The only legal negative value is -1.0(entry 0). 1.0 is entry 16.
-  if (Sign) {
-    if (Entry == 16)
-      return 0;
-    return false;
-  }
-
-  return Entry;
-}
-
-int RISCVLoadFPImm::getLoadFP64Imm(const APFloat &FPImm) {
-  assert(&FPImm.getSemantics() == &APFloat::IEEEdouble());
-
-  APInt Imm = FPImm.bitcastToAPInt();
-
-  if (Imm.extractBitsAsZExtValue(50, 0) != 0)
-    return -1;
-
-  bool Sign = Imm.extractBitsAsZExtValue(1, 63);
-  uint8_t Mantissa = Imm.extractBitsAsZExtValue(2, 50);
-  uint16_t Exp = Imm.extractBitsAsZExtValue(11, 52);
-
-  auto EMI = llvm::lower_bound(LoadFP64ImmArr, std::make_pair(Exp, Mantissa));
-  if (EMI == std::end(LoadFP64ImmArr) || EMI->first != Exp ||
-      EMI->second != Mantissa)
-    return -1;
-
-  // Table doesn't have entry 0.
-  int Entry = std::distance(std::begin(LoadFP64ImmArr), EMI) + 1;
+  // Table doesn't have entry 0 or 1.
+  int Entry = std::distance(std::begin(LoadFP32ImmArr), EMI) + 2;
 
   // The only legal negative value is -1.0(entry 0). 1.0 is entry 16.
   if (Sign) {
@@ -362,8 +282,8 @@ float RISCVLoadFPImm::getFPImm(unsigned Imm) {
     Imm = 16;
   }
 
-  uint32_t Exp = LoadFP32ImmArr[Imm - 1].first;
-  uint32_t Mantissa = LoadFP32ImmArr[Imm - 1].second;
+  uint32_t Exp = LoadFP32ImmArr[Imm - 2].first;
+  uint32_t Mantissa = LoadFP32ImmArr[Imm - 2].second;
 
   uint32_t I = Sign << 31 | Exp << 23 | Mantissa << 21;
   return bit_cast<float>(I);
index cdb972b..70fdc0e 100644 (file)
@@ -349,20 +349,10 @@ inline static bool isValidRoundingMode(unsigned Mode) {
 namespace RISCVLoadFPImm {
 float getFPImm(unsigned Imm);
 
-/// getLoadFP32Imm - Return a 5-bit binary encoding of the 32-bit
-/// floating-point immediate value. If the value cannot be represented as a
-/// 5-bit binary encoding, then return -1.
-int getLoadFP32Imm(const APFloat &FPImm);
-
-/// getLoadFP64Imm - Return a 5-bit binary encoding of the 64-bit
-/// floating-point immediate value. If the value cannot be represented as a
-/// 5-bit binary encoding, then return -1.
-int getLoadFP64Imm(const APFloat &FPImm);
-
-/// getLoadFP16Imm - Return a 5-bit binary encoding of the 16-bit
-/// floating-point immediate value. If the value cannot be represented as a
-/// 5-bit binary encoding, then return -1.
-int getLoadFP16Imm(const APFloat &FPImm);
+/// getLoadFPImm - Return a 5-bit binary encoding of the floating-point
+/// immediate value. If the value cannot be represented as a 5-bit binary
+/// encoding, then return -1.
+int getLoadFPImm(APFloat FPImm);
 } // namespace RISCVLoadFPImm
 
 namespace RISCVSysReg {
index 8f68dab..c3c63ff 100644 (file)
@@ -1540,21 +1540,20 @@ bool RISCVTargetLowering::isOffsetFoldingLegal(
 }
 
 bool RISCVTargetLowering::isLegalZfaFPImm(const APFloat &Imm, EVT VT) const {
-  if (!Subtarget.hasStdExtZfa() || !VT.isSimple())
+  if (!Subtarget.hasStdExtZfa())
     return false;
 
-  switch (VT.getSimpleVT().SimpleTy) {
-  default:
-    return false;
-  case MVT::f16:
-    return (Subtarget.hasStdExtZfh() || Subtarget.hasStdExtZvfh()) &&
-           RISCVLoadFPImm::getLoadFP16Imm(Imm) != -1;
-  case MVT::f32:
-    return RISCVLoadFPImm::getLoadFP32Imm(Imm) != -1;
-  case MVT::f64:
+  bool IsSupportedVT = false;
+  if (VT == MVT::f16) {
+    IsSupportedVT = Subtarget.hasStdExtZfh() || Subtarget.hasStdExtZvfh();
+  } else if (VT == MVT::f32) {
+    IsSupportedVT = true;
+  } else if (VT == MVT::f64) {
     assert(Subtarget.hasStdExtD() && "Expect D extension");
-    return RISCVLoadFPImm::getLoadFP64Imm(Imm) != -1;
+    IsSupportedVT = true;
   }
+
+  return IsSupportedVT && RISCVLoadFPImm::getLoadFPImm(Imm) != -1;
 }
 
 bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
index 96982a4..bac6422 100644 (file)
@@ -179,20 +179,13 @@ def : InstAlias<"fgeq.h $rd, $rs, $rt",
 // Codegen patterns
 //===----------------------------------------------------------------------===//
 
-def fp32imm_to_loadfpimm : SDNodeXForm<fpimm, [{
-  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFP32Imm(N->getValueAPF()),
+def fpimm_to_loadfpimm : SDNodeXForm<fpimm, [{
+  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFPImm(N->getValueAPF()),
                                    SDLoc(N), Subtarget->getXLenVT());}]>;
 
-def fp64imm_to_loadfpimm : SDNodeXForm<fpimm, [{
-  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFP64Imm(N->getValueAPF()),
-                                   SDLoc(N), Subtarget->getXLenVT());}]>;
-
-def fp16imm_to_loadfpimm : SDNodeXForm<fpimm, [{
-  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFP16Imm(N->getValueAPF()),
-                                   SDLoc(N), Subtarget->getXLenVT());}]>;
 
 let Predicates = [HasStdExtZfa] in {
-def : Pat<(f32 fpimm:$imm), (FLI_S (fp32imm_to_loadfpimm fpimm:$imm))>;
+def : Pat<(f32 fpimm:$imm), (FLI_S (fpimm_to_loadfpimm fpimm:$imm))>;
 
 def: PatFprFpr<fminimum, FMINM_S, FPR32>;
 def: PatFprFpr<fmaximum, FMAXM_S, FPR32>;
@@ -216,7 +209,7 @@ def: PatSetCC<FPR32, strict_fsetcc, SETOLE, FLEQ_S>;
 } // Predicates = [HasStdExtZfa]
 
 let Predicates = [HasStdExtZfa, HasStdExtD] in {
-def : Pat<(f64 fpimm:$imm), (FLI_D (fp64imm_to_loadfpimm fpimm:$imm))>;
+def : Pat<(f64 fpimm:$imm), (FLI_D (fpimm_to_loadfpimm fpimm:$imm))>;
 
 def: PatFprFpr<fminimum, FMINM_D, FPR64>;
 def: PatFprFpr<fmaximum, FMAXM_D, FPR64>;
@@ -246,7 +239,7 @@ def : Pat<(RISCVBuildPairF64 GPR:$rs1, GPR:$rs2),
 }
 
 let Predicates = [HasStdExtZfa, HasStdExtZfh] in {
-def : Pat<(f16 fpimm:$imm), (FLI_H (fp16imm_to_loadfpimm fpimm:$imm))>;
+def : Pat<(f16 fpimm:$imm), (FLI_H (fpimm_to_loadfpimm fpimm:$imm))>;
 
 def: PatFprFpr<fminimum, FMINM_H, FPR16>;
 def: PatFprFpr<fmaximum, FMAXM_H, FPR16>;