From 2d58925362eb59c5b3e74019afb8e35b712913bd Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Sat, 29 Apr 2023 22:55:41 -0700 Subject: [PATCH] [LegalizeVectorOps][RISCV] Support condition code legalization for ISD::STRICT_FSETCC/FSETCCS during LegalizeVectorOps. Switch RISC-V to legalize during LegalizeVectorOps instead of LegalizeDAG. LegalizeDAG uses the OpVT for legalize action while LegalizeVectorOps uses the result VT. We really should fix that. --- .../lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp | 65 +++++++++++++++++----- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4 +- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index 18c5fc6..9fcb75f 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -296,7 +296,17 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { if (Op.getOpcode() == ISD::STRICT_SINT_TO_FP || Op.getOpcode() == ISD::STRICT_UINT_TO_FP) ValVT = Node->getOperand(1).getValueType(); - Action = TLI.getOperationAction(Node->getOpcode(), ValVT); + if (Op.getOpcode() == ISD::STRICT_FSETCC || + Op.getOpcode() == ISD::STRICT_FSETCCS) { + MVT OpVT = Node->getOperand(1).getSimpleValueType(); + ISD::CondCode CCCode = cast(Node->getOperand(3))->get(); + Action = TLI.getCondCodeAction(CCCode, OpVT); + if (Action == TargetLowering::Legal) + Action = + TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0)); + } else { + Action = TLI.getOperationAction(Node->getOpcode(), ValVT); + } // If we're asked to expand a strict vector floating-point operation, // by default we're going to simply unroll it. That is usually the // best approach, except in the case where the resulting strict (scalar) @@ -1516,39 +1526,54 @@ void VectorLegalizer::ExpandSETCC(SDNode *Node, SmallVectorImpl &Results) { bool NeedInvert = false; bool IsVP = Node->getOpcode() == ISD::VP_SETCC; - SDLoc dl(Node); - MVT OpVT = Node->getOperand(0).getSimpleValueType(); - ISD::CondCode CCCode = cast(Node->getOperand(2))->get(); + bool IsStrict = Node->getOpcode() == ISD::STRICT_FSETCC || + Node->getOpcode() == ISD::STRICT_FSETCCS; + bool IsSignaling = Node->getOpcode() == ISD::STRICT_FSETCCS; + unsigned Offset = IsStrict ? 1 : 0; + + SDValue Chain = IsStrict ? Node->getOperand(0) : SDValue(); + SDValue LHS = Node->getOperand(0 + Offset); + SDValue RHS = Node->getOperand(1 + Offset); + SDValue CC = Node->getOperand(2 + Offset); + + MVT OpVT = LHS.getSimpleValueType(); + ISD::CondCode CCCode = cast(CC)->get(); if (TLI.getCondCodeAction(CCCode, OpVT) != TargetLowering::Expand) { + if (IsStrict) { + UnrollStrictFPOp(Node, Results); + return; + } Results.push_back(UnrollVSETCC(Node)); return; } - SDValue Chain; - SDValue LHS = Node->getOperand(0); - SDValue RHS = Node->getOperand(1); - SDValue CC = Node->getOperand(2); SDValue Mask, EVL; if (IsVP) { - Mask = Node->getOperand(3); - EVL = Node->getOperand(4); + Mask = Node->getOperand(3 + Offset); + EVL = Node->getOperand(4 + Offset); } + SDLoc dl(Node); bool Legalized = TLI.LegalizeSetCCCondCode(DAG, Node->getValueType(0), LHS, RHS, CC, Mask, - EVL, NeedInvert, dl, Chain); + EVL, NeedInvert, dl, Chain, IsSignaling); if (Legalized) { // If we expanded the SETCC by swapping LHS and RHS, or by inverting the // condition code, create a new SETCC node. if (CC.getNode()) { - if (!IsVP) - LHS = DAG.getNode(ISD::SETCC, dl, Node->getValueType(0), LHS, RHS, CC, - Node->getFlags()); - else + if (IsStrict) { + LHS = DAG.getNode(Node->getOpcode(), dl, Node->getVTList(), + {Chain, LHS, RHS, CC}, Node->getFlags()); + Chain = LHS.getValue(1); + } else if (IsVP) { LHS = DAG.getNode(ISD::VP_SETCC, dl, Node->getValueType(0), {LHS, RHS, CC, Mask, EVL}, Node->getFlags()); + } else { + LHS = DAG.getNode(ISD::SETCC, dl, Node->getValueType(0), LHS, RHS, CC, + Node->getFlags()); + } } // If we expanded the SETCC by inverting the condition code, then wrap @@ -1560,6 +1585,8 @@ void VectorLegalizer::ExpandSETCC(SDNode *Node, LHS = DAG.getVPLogicalNOT(dl, LHS, Mask, EVL, LHS->getValueType(0)); } } else { + assert(!IsStrict && "Don't know how to expand for strict nodes."); + // Otherwise, SETCC for the given comparison type must be completely // illegal; expand it into a SELECT_CC. EVT VT = Node->getValueType(0); @@ -1571,6 +1598,8 @@ void VectorLegalizer::ExpandSETCC(SDNode *Node, } Results.push_back(LHS); + if (IsStrict) + Results.push_back(Chain); } void VectorLegalizer::ExpandUADDSUBO(SDNode *Node, @@ -1618,6 +1647,12 @@ void VectorLegalizer::ExpandStrictFPOp(SDNode *Node, return; } + if (Node->getOpcode() == ISD::STRICT_FSETCC || + Node->getOpcode() == ISD::STRICT_FSETCCS) { + ExpandSETCC(Node, Results); + return; + } + UnrollStrictFPOp(Node, Results); } diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 4ce687c..84dceb5 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -635,7 +635,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECTOR_REVERSE, VT, Custom); - setOperationAction({ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS}, VT, Legal); + setOperationAction({ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS}, VT, Custom); setOperationPromotedToType( ISD::VECTOR_SPLICE, VT, @@ -898,7 +898,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, if (VT.getVectorElementType() == MVT::i1) setOperationAction({ISD::STRICT_FSETCC, ISD::STRICT_FSETCCS}, VT, - Legal); + Custom); setOperationAction(ISD::SELECT, VT, Custom); -- 2.7.4