From 9239d3a3eaf278ecf36376760b21e49512de6ac6 Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Mon, 29 May 2023 19:44:43 -0700
Subject: [PATCH] [RISCV] Teach performCombineVMergeAndVOps to handle FMA
 instructions.

Previously we only handled instructions with merge ops that were also
masked. This patch supports instructions with merge ops that aren't
masked, like FMA.

I'm only folding into a TU vmerge for now. Supporting TA vmerge
shouldn't be much more work, but we need to make sure we get the policy
operand for the result correct. And of course we need more tests.

Reviewed By: fakepaper56, frasercrmck

Differential Revision: https://reviews.llvm.org/D151596
---
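A minimal before/after sketch of the fold, lifted from the
vfmacc_vv_nxv1f32_tu test added below (the IR names and the CHECK output
are the test's own, not new codegen):

  ; Unmasked VP fma whose addend %c is also the false operand of the
  ; TU vp.merge that consumes it:
  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %c, <vscale x 1 x i1> %allones, i32 %evl)
  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)

  ; Since the fma's merge operand (the tied addend) equals the vmerge's
  ; false operand, this now selects to a single masked, tail-undisturbed
  ; vfwmacc instead of an unmasked vfwmacc followed by a vmerge.vvm:
  ;   vsetvli zero, a0, e16, mf4, tu, mu
  ;   vfwmacc.vv v10, v8, v9, v0.t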
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp |  42 +++++++----
 llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll   | 105 ++++++++++++++++++++++++++++
 llvm/test/CodeGen/RISCV/rvv/vfwmsac-vp.ll   |  35 ++++++++++
 3 files changed, 169 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index e4dd7ec9..8981e4e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3249,18 +3249,40 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N, bool IsTA) {
   uint64_t TrueTSFlags = TII->get(TrueOpc).TSFlags;
   bool HasMergeOp = RISCVII::hasMergeOp(TrueTSFlags);
 
+  bool IsMasked = false;
+  const RISCV::RISCVMaskedPseudoInfo *Info =
+      RISCV::lookupMaskedIntrinsicByUnmaskedTA(TrueOpc);
+  if (!Info && HasMergeOp) {
+    Info = RISCV::getMaskedPseudoInfo(TrueOpc);
+    IsMasked = true;
+  }
+
+  if (!Info)
+    return false;
+
   if (HasMergeOp) {
     // The vmerge instruction must be TU.
+    // FIXME: This could be relaxed, but we need to handle the policy for the
+    // resulting op correctly.
     if (IsTA)
       return false;
-    SDValue MergeOpN = N->getOperand(0);
     SDValue MergeOpTrue = True->getOperand(0);
     // Both the vmerge instruction and the True instruction must have the same
-    // merge operand. The vmerge instruction must have an all 1s mask since
-    // we're going to keep the mask from the True instruction.
+    // merge operand.
+    if (False != MergeOpTrue)
+      return false;
+  }
+
+  if (IsMasked) {
+    assert(HasMergeOp && "Expected merge op");
+    // The vmerge instruction must be TU.
+    if (IsTA)
+      return false;
+    // The vmerge instruction must have an all 1s mask since we're going to
+    // keep the mask from the True instruction.
     // FIXME: Support mask agnostic True instruction which would have an
     // undef merge operand.
-    if (MergeOpN != MergeOpTrue || !usesAllOnesMask(N, /* MaskOpIdx */ 3))
+    if (!usesAllOnesMask(N, /* MaskOpIdx */ 3))
       return false;
   }
 
@@ -3269,13 +3291,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N, bool IsTA) {
   if (TII->get(TrueOpc).hasUnmodeledSideEffects())
     return false;
 
-  const RISCV::RISCVMaskedPseudoInfo *Info =
-      HasMergeOp ? RISCV::getMaskedPseudoInfo(TrueOpc)
-                 : RISCV::lookupMaskedIntrinsicByUnmaskedTA(TrueOpc);
-
-  if (!Info)
-    return false;
-
   // The last operand of a masked instruction may be glued.
   bool HasGlueOp = True->getGluedNode() != nullptr;
 
@@ -3324,14 +3339,15 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N, bool IsTA) {
          "Expected instructions with mask have merge operand.");
 
   SmallVector<SDValue, 8> Ops;
-  if (HasMergeOp) {
+  if (IsMasked) {
     Ops.append(True->op_begin(), True->op_begin() + TrueVLIndex);
     Ops.append({VL, /* SEW */ True.getOperand(TrueVLIndex + 1)});
     Ops.push_back(
         CurDAG->getTargetConstant(Policy, DL, Subtarget->getXLenVT()));
     Ops.append(True->op_begin() + TrueVLIndex + 3, True->op_end());
   } else {
-    Ops.push_back(False);
+    if (!HasMergeOp)
+      Ops.push_back(False);
     Ops.append(True->op_begin(), True->op_begin() + TrueVLIndex);
     Ops.append({Mask, VL, /* SEW */ True.getOperand(TrueVLIndex + 1)});
     Ops.push_back(
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
index 9586f62..330eb82 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
@@ -7,6 +7,7 @@
 declare <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float>, <vscale x 1 x float>, <vscale x 1 x float>, <vscale x 1 x i1>, i32)
 declare <vscale x 1 x float> @llvm.vp.fneg.nxv1f32(<vscale x 1 x float>, <vscale x 1 x i1>, i32)
 declare <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half>, <vscale x 1 x i1>, i32)
+declare <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1>, <vscale x 1 x float>, <vscale x 1 x float>, i32)
 
 define <vscale x 1 x float> @vfmacc_vv_nxv1f32(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vfmacc_vv_nxv1f32:
@@ -36,6 +37,56 @@ define <vscale x 1 x float> @vfmacc_vv_nxv1f32_unmasked(<vscale x 1 x half> %a,
   ret <vscale x 1 x float> %v
 }
 
+define <vscale x 1 x float> @vfmacc_vv_nxv1f32_tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vv_nxv1f32_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; CHECK-NEXT:    vfwmacc.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %aext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %allones, i32 %evl)
+  %bext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %c, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
+; FIXME: Support this case?
+define <vscale x 1 x float> @vfmacc_vv_nxv1f32_masked__tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vv_nxv1f32_masked__tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, ta, ma
+; CHECK-NEXT:    vmv1r.v v11, v10
+; CHECK-NEXT:    vfwmacc.vv v11, v8, v9, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, tu, ma
+; CHECK-NEXT:    vmerge.vvm v10, v10, v11, v0
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %aext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %m, i32 %evl)
+  %bext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %m, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
+define <vscale x 1 x float> @vfmacc_vv_nxv1f32_unmasked_tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vv_nxv1f32_unmasked_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, ma
+; CHECK-NEXT:    vfwmacc.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %aext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %allones, i32 %evl)
+  %bext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %c, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %allones, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
 define <vscale x 1 x float> @vfmacc_vf_nxv1f32(<vscale x 1 x half> %va, half %b, <vscale x 1 x float> %vc, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vfmacc_vf_nxv1f32:
 ; CHECK:       # %bb.0:
@@ -83,6 +134,60 @@ define <vscale x 1 x float> @vfmacc_vf_nxv1f32_unmasked(<vscale x 1 x half> %va,
   ret <vscale x 1 x float> %v
 }
 
+define <vscale x 1 x float> @vfmacc_vf_nxv1f32_tu(<vscale x 1 x half> %va, half %b, <vscale x 1 x float> %vc, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vf_nxv1f32_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; CHECK-NEXT:    vfwmacc.vf v9, fa0, v8, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %elt.head = insertelement <vscale x 1 x half> poison, half %b, i32 0
+  %vb = shufflevector <vscale x 1 x half> %elt.head, <vscale x 1 x half> poison, <vscale x 1 x i32> zeroinitializer
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %vaext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %va, <vscale x 1 x i1> %allones, i32 %evl)
+  %vbext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %vb, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %vaext, <vscale x 1 x float> %vbext, <vscale x 1 x float> %vc, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %vc, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
+define <vscale x 1 x float> @vfmacc_vf_nxv1f32_commute_tu(<vscale x 1 x half> %va, half %b, <vscale x 1 x float> %vc, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vf_nxv1f32_commute_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; CHECK-NEXT:    vfwmacc.vf v9, fa0, v8, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %elt.head = insertelement <vscale x 1 x half> poison, half %b, i32 0
+  %vb = shufflevector <vscale x 1 x half> %elt.head, <vscale x 1 x half> poison, <vscale x 1 x i32> zeroinitializer
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %vaext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %va, <vscale x 1 x i1> %allones, i32 %evl)
+  %vbext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %vb, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %vbext, <vscale x 1 x float> %vaext, <vscale x 1 x float> %vc, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %vc, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
+define <vscale x 1 x float> @vfmacc_vf_nxv1f32_unmasked_tu(<vscale x 1 x half> %va, half %b, <vscale x 1 x float> %vc, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vf_nxv1f32_unmasked_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, ma
+; CHECK-NEXT:    vfwmacc.vf v9, fa0, v8
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %elt.head = insertelement <vscale x 1 x half> poison, half %b, i32 0
+  %vb = shufflevector <vscale x 1 x half> %elt.head, <vscale x 1 x half> poison, <vscale x 1 x i32> zeroinitializer
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %vaext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %va, <vscale x 1 x i1> %allones, i32 %evl)
+  %vbext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %vb, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %vaext, <vscale x 1 x float> %vbext, <vscale x 1 x float> %vc, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %allones, <vscale x 1 x float> %v, <vscale x 1 x float> %vc, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
 declare <vscale x 2 x float> @llvm.vp.fma.nxv2f32(<vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x float> @llvm.vp.fneg.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x float> @llvm.vp.merge.nxv2f32(<vscale x 2 x i1>, <vscale x 2 x float>, <vscale x 2 x float>, i32)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwmsac-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwmsac-vp.ll
index 578caa3..b27a1e0 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwmsac-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwmsac-vp.ll
@@ -7,6 +7,7 @@
 declare <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float>, <vscale x 1 x float>, <vscale x 1 x float>, <vscale x 1 x i1>, i32)
 declare <vscale x 1 x float> @llvm.vp.fneg.nxv1f32(<vscale x 1 x float>, <vscale x 1 x i1>, i32)
 declare <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half>, <vscale x 1 x i1>, i32)
+declare <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1>, <vscale x 1 x float>, <vscale x 1 x float>, i32)
 
 define <vscale x 1 x float> @vmfsac_vv_nxv1f32(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vmfsac_vv_nxv1f32:
@@ -38,6 +39,40 @@ define <vscale x 1 x float> @vmfsac_vv_nxv1f32_unmasked(<vscale x 1 x half> %a,
   ret <vscale x 1 x float> %v
 }
 
+define <vscale x 1 x float> @vmfsac_vv_nxv1f32_tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vmfsac_vv_nxv1f32_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; CHECK-NEXT:    vfwmsac.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %aext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %allones, i32 %evl)
+  %bext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %allones, i32 %evl)
+  %negc = call <vscale x 1 x float> @llvm.vp.fneg.nxv1f32(<vscale x 1 x float> %c, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %negc, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
+define <vscale x 1 x float> @vmfsac_vv_nxv1f32_unmasked_tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, i32 zeroext %evl) {
+; CHECK-LABEL: vmfsac_vv_nxv1f32_unmasked_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, ma
+; CHECK-NEXT:    vfwmsac.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %aext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %allones, i32 %evl)
+  %bext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %allones, i32 %evl)
+  %negc = call <vscale x 1 x float> @llvm.vp.fneg.nxv1f32(<vscale x 1 x float> %c, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %negc, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %allones, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
 define <vscale x 1 x float> @vmfsac_vf_nxv1f32(<vscale x 1 x half> %a, half %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vmfsac_vf_nxv1f32:
 ; CHECK:       # %bb.0:
-- 
2.7.4