[RISCV] Custom lowering of SET_ROUNDING
authorSerge Pavlov <sepavloff@gmail.com>
Tue, 10 Nov 2020 16:51:34 +0000 (23:51 +0700)
committerSerge Pavlov <sepavloff@gmail.com>
Thu, 22 Apr 2021 08:04:55 +0000 (15:04 +0700)
Differential Revision: https://reviews.llvm.org/D91242

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/test/CodeGen/RISCV/fpenv.ll

index 0e060ce..4ca9319 100644 (file)
@@ -374,6 +374,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   if (Subtarget.hasStdExtF()) {
     setOperationAction(ISD::FLT_ROUNDS_, XLenVT, Custom);
+    setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
   }
 
   setOperationAction(ISD::GlobalAddress, XLenVT, Custom);
@@ -2167,6 +2168,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerMSCATTER(Op, DAG);
   case ISD::FLT_ROUNDS_:
     return lowerGET_ROUNDING(Op, DAG);
+  case ISD::SET_ROUNDING:
+    return lowerSET_ROUNDING(Op, DAG);
   }
 }
 
@@ -4144,6 +4147,36 @@ SDValue RISCVTargetLowering::lowerGET_ROUNDING(SDValue Op,
   return DAG.getMergeValues({Masked, Chain}, DL);
 }
 
+SDValue RISCVTargetLowering::lowerSET_ROUNDING(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  const MVT XLenVT = Subtarget.getXLenVT();
+  SDLoc DL(Op);
+  SDValue Chain = Op->getOperand(0);
+  SDValue RMValue = Op->getOperand(1);
+  SDValue SysRegNo = DAG.getConstant(
+      RISCVSysReg::lookupSysRegByName("FRM")->Encoding, DL, XLenVT);
+
+  // Encoding used for rounding mode in RISCV differs from that used in
+  // FLT_ROUNDS. To convert it the C rounding mode is used as an index in
+  // a table, which consists of a sequence of 4-bit fields, each representing
+  // corresponding RISCV mode.
+  static const unsigned Table =
+      (RISCVFPRndMode::RNE << 4 * int(RoundingMode::NearestTiesToEven)) |
+      (RISCVFPRndMode::RTZ << 4 * int(RoundingMode::TowardZero)) |
+      (RISCVFPRndMode::RDN << 4 * int(RoundingMode::TowardNegative)) |
+      (RISCVFPRndMode::RUP << 4 * int(RoundingMode::TowardPositive)) |
+      (RISCVFPRndMode::RMM << 4 * int(RoundingMode::NearestTiesToAway));
+
+  SDValue Shift = DAG.getNode(ISD::SHL, DL, XLenVT, RMValue,
+                              DAG.getConstant(2, DL, XLenVT));
+  SDValue Shifted = DAG.getNode(ISD::SRL, DL, XLenVT,
+                                DAG.getConstant(Table, DL, XLenVT), Shift);
+  RMValue = DAG.getNode(ISD::AND, DL, XLenVT, Shifted,
+                        DAG.getConstant(0x7, DL, XLenVT));
+  return DAG.getNode(RISCVISD::WRITE_CSR, DL, MVT::Other, Chain, SysRegNo,
+                     RMValue);
+}
+
 // Returns the opcode of the target-specific SDNode that implements the 32-bit
 // form of the given Opcode.
 static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
index 7a09b72..ddccdb2 100644 (file)
@@ -534,6 +534,7 @@ private:
   SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
                                             unsigned ExtendOpc) const;
   SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
 
   bool isEligibleForTailCallOptimization(
       CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
index ed62d75..28fac83 100644 (file)
@@ -26,4 +26,100 @@ define i32 @func_01() {
   ret i32 %rm
 }
 
+define void @func_02(i32 %rm) {
+; RV32IF-LABEL: func_02:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    slli a0, a0, 2
+; RV32IF-NEXT:    lui a1, 66
+; RV32IF-NEXT:    addi a1, a1, 769
+; RV32IF-NEXT:    srl a0, a1, a0
+; RV32IF-NEXT:    andi a0, a0, 7
+; RV32IF-NEXT:    fsrm a0
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: func_02:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    slli a0, a0, 32
+; RV64IF-NEXT:    srli a0, a0, 30
+; RV64IF-NEXT:    lui a1, 66
+; RV64IF-NEXT:    addiw a1, a1, 769
+; RV64IF-NEXT:    srl a0, a1, a0
+; RV64IF-NEXT:    andi a0, a0, 7
+; RV64IF-NEXT:    fsrm a0
+; RV64IF-NEXT:    ret
+  call void @llvm.set.rounding(i32 %rm)
+  ret void
+}
+
+define void @func_03() {
+; RV32IF-LABEL: func_03:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    fsrmi 1
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: func_03:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    fsrmi 1
+; RV64IF-NEXT:    ret
+  call void @llvm.set.rounding(i32 0)
+  ret void
+}
+
+define void @func_04() {
+; RV32IF-LABEL: func_04:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    fsrmi 0
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: func_04:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    fsrmi 0
+; RV64IF-NEXT:    ret
+  call void @llvm.set.rounding(i32 1)
+  ret void
+}
+
+define void @func_05() {
+; RV32IF-LABEL: func_05:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    fsrmi 3
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: func_05:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    fsrmi 3
+; RV64IF-NEXT:    ret
+  call void @llvm.set.rounding(i32 2)
+  ret void
+}
+
+define void @func_06() {
+; RV32IF-LABEL: func_06:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    fsrmi 2
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: func_06:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    fsrmi 2
+; RV64IF-NEXT:    ret
+  call void @llvm.set.rounding(i32 3)
+  ret void
+}
+
+define void @func_07() {
+; RV32IF-LABEL: func_07:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    fsrmi 4
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: func_07:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    fsrmi 4
+; RV64IF-NEXT:    ret
+  call void @llvm.set.rounding(i32 4)
+  ret void
+}
+
+declare void @llvm.set.rounding(i32)
 declare i32 @llvm.flt.rounds()