[RISCV] Made fsqrtv pseudoinstruction SEW-aware
author Nitin John Raj <nitin.raj@sifive.com>
Fri, 24 Feb 2023 21:45:31 +0000 (13:45 -0800)
committer Nitin John Raj <nitin.raj@sifive.com>
Fri, 24 Mar 2023 23:33:25 +0000 (16:33 -0700)
llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
llvm/lib/Target/RISCV/RISCVScheduleV.td

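diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 327b3c5..3bf70e4 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td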
@@ -2328,18 +2328,25 @@ multiclass VPseudoVCLS_V {
 multiclass VPseudoVSQR_V {
   foreach m = MxListF in {
     defvar mx = m.MX;
-    defvar WriteVFSqrtV_MX = !cast<SchedWrite>("WriteVFSqrtV_" # mx);
-    defvar ReadVFSqrtV_MX = !cast<SchedRead>("ReadVFSqrtV_" # mx);
+    defvar sews = SchedSEWSet<m.MX>.val;
 
-    let VLMul = m.value in {
-      def "_V_" # mx : VPseudoUnaryNoMask<m.vrclass, m.vrclass>,
-                       Sched<[WriteVFSqrtV_MX, ReadVFSqrtV_MX, ReadVMask]>;
-      def "_V_" # mx # "_TU": VPseudoUnaryNoMaskTU<m.vrclass, m.vrclass>,
-                              Sched<[WriteVFSqrtV_MX, ReadVFSqrtV_MX, ReadVMask]>;
-      def "_V_" # mx # "_MASK" : VPseudoUnaryMaskTA<m.vrclass, m.vrclass>,
-                                 RISCVMaskedPseudo</*MaskOpIdx*/ 2>,
-                                 Sched<[WriteVFSqrtV_MX, ReadVFSqrtV_MX, ReadVMask]>;
-    }
+    let VLMul = m.value in
+      foreach e = sews in {
+        defvar suffix = "_" # mx # "_E" # e;
+        defvar WriteVFSqrtV_MX_E = !cast<SchedWrite>("WriteVFSqrtV" # suffix);
+        defvar ReadVFSqrtV_MX_E = !cast<SchedRead>("ReadVFSqrtV" # suffix);
+
+        def "_V" # suffix : VPseudoUnaryNoMask<m.vrclass, m.vrclass>,
+                            Sched<[WriteVFSqrtV_MX_E, ReadVFSqrtV_MX_E,
+                                   ReadVMask]>;
+        def "_V" # suffix # "_TU": VPseudoUnaryNoMaskTU<m.vrclass, m.vrclass>,
+                                   Sched<[WriteVFSqrtV_MX_E, ReadVFSqrtV_MX_E,
+                                          ReadVMask]>;
+        def "_V" # suffix # "_MASK" : VPseudoUnaryMaskTA<m.vrclass, m.vrclass>,
+                                      RISCVMaskedPseudo</*MaskOpIdx*/ 2>,
+                                      Sched<[WriteVFSqrtV_MX_E, ReadVFSqrtV_MX_E,
+                                             ReadVMask]>;
+      }
   }
 }
 
@@ -3835,6 +3842,23 @@ class VPatUnaryNoMask<string intrinsic_name,
                    (op2_type op2_reg_class:$rs2),
                    GPR:$vl, sew)>;
 
+class VPatUnaryNoMask_E<string intrinsic_name,
+                        string inst,
+                        string kind,
+                        ValueType result_type,
+                        ValueType op2_type,
+                        int log2sew,
+                        LMULInfo vlmul,
+                        int sew,
+                        VReg op2_reg_class> :
+  Pat<(result_type (!cast<Intrinsic>(intrinsic_name)
+                   (result_type undef),
+                   (op2_type op2_reg_class:$rs2),
+                   VLOpFrag)),
+                   (!cast<Instruction>(inst#"_"#kind#"_"#vlmul.MX#"_E"#sew)
+                   (op2_type op2_reg_class:$rs2),
+                   GPR:$vl, log2sew)>;
+
 class VPatUnaryNoMaskTU<string intrinsic_name,
                         string inst,
                         string kind,
@@ -3853,6 +3877,25 @@ class VPatUnaryNoMaskTU<string intrinsic_name,
                    (op2_type op2_reg_class:$rs2),
                    GPR:$vl, sew)>;
 
+class VPatUnaryNoMaskTU_E<string intrinsic_name,
+                          string inst,
+                          string kind,
+                          ValueType result_type,
+                          ValueType op2_type,
+                          int log2sew,
+                          LMULInfo vlmul,
+                          int sew,
+                          VReg result_reg_class,
+                          VReg op2_reg_class> :
+  Pat<(result_type (!cast<Intrinsic>(intrinsic_name)
+                   (result_type result_reg_class:$merge),
+                   (op2_type op2_reg_class:$rs2),
+                   VLOpFrag)),
+                   (!cast<Instruction>(inst#"_"#kind#"_"#vlmul.MX#"_E"#sew#"_TU")
+                   (result_type result_reg_class:$merge),
+                   (op2_type op2_reg_class:$rs2),
+                   GPR:$vl, log2sew)>;
+
 class VPatUnaryMask<string intrinsic_name,
                     string inst,
                     string kind,
@@ -3893,6 +3936,27 @@ class VPatUnaryMaskTA<string intrinsic_name,
                    (op2_type op2_reg_class:$rs2),
                    (mask_type V0), GPR:$vl, sew, (XLenVT timm:$policy))>;
 
+class VPatUnaryMaskTA_E<string intrinsic_name,
+                        string inst,
+                        string kind,
+                        ValueType result_type,
+                        ValueType op2_type,
+                        ValueType mask_type,
+                        int log2sew,
+                        LMULInfo vlmul,
+                        int sew,
+                        VReg result_reg_class,
+                        VReg op2_reg_class> :
+  Pat<(result_type (!cast<Intrinsic>(intrinsic_name#"_mask")
+                   (result_type result_reg_class:$merge),
+                   (op2_type op2_reg_class:$rs2),
+                   (mask_type V0),
+                   VLOpFrag, (XLenVT timm:$policy))),
+                   (!cast<Instruction>(inst#"_"#kind#"_"#vlmul.MX#"_E"#sew#"_MASK")
+                   (result_type result_reg_class:$merge),
+                   (op2_type op2_reg_class:$rs2),
+                   (mask_type V0), GPR:$vl, log2sew, (XLenVT timm:$policy))>;
+
 class VPatMaskUnaryNoMask<string intrinsic_name,
                           string inst,
                           MTypeInfo mti> :
@@ -4336,6 +4400,23 @@ multiclass VPatUnaryV_V<string intrinsic, string instruction,
   }
 }
 
+multiclass VPatUnaryV_V_E<string intrinsic, string instruction,
+                        list<VTypeInfo> vtilist> {
+  foreach vti = vtilist in {
+    def : VPatUnaryNoMask_E<intrinsic, instruction, "V",
+                            vti.Vector, vti.Vector,
+                            vti.Log2SEW, vti.LMul, vti.SEW, vti.RegClass>;
+    def : VPatUnaryNoMaskTU_E<intrinsic, instruction, "V",
+                              vti.Vector, vti.Vector,
+                              vti.Log2SEW, vti.LMul, vti.SEW,
+                              vti.RegClass, vti.RegClass>;
+    def : VPatUnaryMaskTA_E<intrinsic, instruction, "V",
+                            vti.Vector, vti.Vector, vti.Mask,
+                            vti.Log2SEW, vti.LMul, vti.SEW,
+                            vti.RegClass, vti.RegClass>;
+  }
+}
+
 multiclass VPatNullaryV<string intrinsic, string instruction>
 {
   foreach vti = AllIntegerVectors in {
@@ -6292,7 +6373,7 @@ defm : VPatTernaryW_VV_VX<"int_riscv_vfwnmsac", "PseudoVFWNMSAC", AllWidenableFl
 //===----------------------------------------------------------------------===//
 // 13.8. Vector Floating-Point Square-Root Instruction
 //===----------------------------------------------------------------------===//
-defm : VPatUnaryV_V<"int_riscv_vfsqrt", "PseudoVFSQRT", AllFloatVectors>;
+defm : VPatUnaryV_V_E<"int_riscv_vfsqrt", "PseudoVFSQRT", AllFloatVectors>;
 
 //===----------------------------------------------------------------------===//
 // 13.9. Vector Floating-Point Reciprocal Square-Root Estimate Instruction
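diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index ee75392..95170d3 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td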
@@ -1006,7 +1006,7 @@ defm : VPatWidenFPNegMulSacSDNode_VV_VF<"PseudoVFWNMSAC">;
 foreach vti = AllFloatVectors in {
   // 13.8. Vector Floating-Point Square-Root Instruction
   def : Pat<(fsqrt (vti.Vector vti.RegClass:$rs2)),
-            (!cast<Instruction>("PseudoVFSQRT_V_"# vti.LMul.MX)
+            (!cast<Instruction>("PseudoVFSQRT_V_"# vti.LMul.MX#"_E"#vti.SEW)
                  vti.RegClass:$rs2, vti.AVL, vti.Log2SEW)>;
 
   // 13.12. Vector Floating-Point Sign-Injection Instructions
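diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 218610b..d3e9aa3 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td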
@@ -1802,7 +1802,7 @@ foreach vti = AllFloatVectors in {
   // 13.8. Vector Floating-Point Square-Root Instruction
   def : Pat<(riscv_fsqrt_vl (vti.Vector vti.RegClass:$rs2), (vti.Mask V0),
                             VLOpFrag),
-            (!cast<Instruction>("PseudoVFSQRT_V_"# vti.LMul.MX #"_MASK")
+            (!cast<Instruction>("PseudoVFSQRT_V_"# vti.LMul.MX # "_E" # vti.SEW # "_MASK")
                  (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TA_MA)>;
 
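diff --git a/llvm/lib/Target/RISCV/RISCVScheduleV.td b/llvm/lib/Target/RISCV/RISCVScheduleV.td
index 6c0a04a..0161226 100644
--- a/llvm/lib/Target/RISCV/RISCVScheduleV.td
+++ b/llvm/lib/Target/RISCV/RISCVScheduleV.td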
@@ -309,7 +309,7 @@ defm "" : LMULSchedWrites<"WriteVFMulAddF">;
 defm "" : LMULSchedWritesFW<"WriteVFWMulAddV">;
 defm "" : LMULSchedWritesFW<"WriteVFWMulAddF">;
 // 13.8. Vector Floating-Point Square-Root Instruction
-defm "" : LMULSchedWrites<"WriteVFSqrtV">;
+defm "" : LMULSEWSchedWrites<"WriteVFSqrtV">;
 // 13.9. Vector Floating-Point Reciprocal Square-Root Estimate Instruction
 // 13.10. Vector Floating-Point Reciprocal Estimate Instruction
 defm "" : LMULSchedWrites<"WriteVFRecpV">;
@@ -528,7 +528,7 @@ defm "" : LMULSchedReads<"ReadVFMulAddF">;
 defm "" : LMULSchedReadsFW<"ReadVFWMulAddV">;
 defm "" : LMULSchedReadsFW<"ReadVFWMulAddF">;
 // 13.8. Vector Floating-Point Square-Root Instruction
-defm "" : LMULSchedReads<"ReadVFSqrtV">;
+defm "" : LMULSEWSchedReads<"ReadVFSqrtV">;
 // 13.9. Vector Floating-Point Reciprocal Square-Root Estimate Instruction
 // 13.10. Vector Floating-Point Reciprocal Estimate Instruction
 defm "" : LMULSchedReads<"ReadVFRecpV">;
@@ -757,7 +757,7 @@ defm "" : LMULWriteRes<"WriteVFMulAddV", []>;
 defm "" : LMULWriteRes<"WriteVFMulAddF", []>;
 defm "" : LMULWriteResFW<"WriteVFWMulAddV", []>;
 defm "" : LMULWriteResFW<"WriteVFWMulAddF", []>;
-defm "" : LMULWriteRes<"WriteVFSqrtV", []>;
+defm "" : LMULSEWWriteRes<"WriteVFSqrtV", []>;
 defm "" : LMULWriteRes<"WriteVFRecpV", []>;
 defm "" : LMULWriteRes<"WriteVFCmpV", []>;
 defm "" : LMULWriteRes<"WriteVFCmpF", []>;
@@ -907,7 +907,7 @@ defm "" : LMULReadAdvance<"ReadVFMulAddV", 0>;
 defm "" : LMULReadAdvance<"ReadVFMulAddF", 0>;
 defm "" : LMULReadAdvanceFW<"ReadVFWMulAddV", 0>;
 defm "" : LMULReadAdvanceFW<"ReadVFWMulAddF", 0>;
-defm "" : LMULReadAdvance<"ReadVFSqrtV", 0>;
+defm "" : LMULSEWReadAdvance<"ReadVFSqrtV", 0>;
 defm "" : LMULReadAdvance<"ReadVFRecpV", 0>;
 defm "" : LMULReadAdvance<"ReadVFCmpV", 0>;
 defm "" : LMULReadAdvance<"ReadVFCmpF", 0>;