From 4852f26acde16ad85845a22a86eeae7dcf3287db Mon Sep 17 00:00:00 2001 From: Quentin Colombet Date: Thu, 22 Sep 2022 04:16:08 +0000 Subject: [PATCH] [RISCV][ISel] Refactor the formation of VW operations This patch centralizes all the combines of add|sub|mul with extended operands in one "framework". The rationale for this change is to offer a one-stop-shop for all these transformations so that, in the future, it is easier to make combine decisions for a web of instructions (i.e., instructions connected through s|zext operands). Technically this patch is not NFC because the new version is more powerful than the previous version. In particular, it diverges in two cases: - VWMULSU can now also be produced from `mul(splat, zext)`, whereas previously only `mul(sext, splat)` were supported when `splat`s were involved. (As demonstrated in rvv/fixed-vectors-vwmulsu.ll) - VWSUB(U) can now also be produced from `sub(splat, ext)`, whereas previously only `sub(ext, splat)` were supported when `splat`s were involved. (As demonstrated in rvv/fixed-vectors-vwsub.ll) If we wanted, we could block these transformations to make this patch really NFC. For instance, we could do something similar to `AllowSplatInVW_W`, which prevents the combines to form vw(add|sub)(u)_w when the RHS is a splat. Regarding the "framework" itself, the bulk of the patch is some boilderplate code that abstracts away the actual extensions that are present in the DAG. This allows us to handle `vwadd_w(ext a, b)` as if it was a regular `add(ext a, ext b)`. Since the node `ext b` doesn't actually exist in the DAG, we have a bunch of methods (all in the NodeExtensionHelper class) that fake all that for us. The other half of the change is around `CombineToTry` and `CombineResult`. These helper structures respectively: - Represent the kind of combines that can be applied to a node, and - Store what needs to happen to do that combine. This can be viewed as a two step approach: - First, check if a pattern applies, and - Second apply it. The checks and the materialization of the combines are decoupled so that in the future we can perform several checks and do all the related applies in one go. Differential Revision: https://reviews.llvm.org/D134703 --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 705 +++++++++++++++------ llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll | 26 + .../CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll | 9 +- llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsub.ll | 45 +- .../test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll | 45 +- 5 files changed, 578 insertions(+), 252 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 1544724..f4f74eb 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicsRISCV.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" @@ -45,6 +46,12 @@ using namespace llvm; STATISTIC(NumTailCalls, "Number of tail calls"); +static cl::opt + AllowSplatInVW_W(DEBUG_TYPE "-form-vw-w-with-splat", cl::Hidden, + cl::desc("Allow the formation of VW_W operations (e.g., " + "VWADD_W) with splat constants"), + cl::init(false)); + RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, const RISCVSubtarget &STI) : TargetLowering(TM), Subtarget(STI) { @@ -8204,228 +8211,548 @@ performSIGN_EXTEND_INREGCombine(SDNode *N, SelectionDAG &DAG, return SDValue(); } -// Try to form vwadd(u).wv/wx or vwsub(u).wv/wx. It might later be optimized to -// vwadd(u).vv/vx or vwsub(u).vv/vx. -static SDValue combineADDSUB_VLToVWADDSUB_VL(SDNode *N, SelectionDAG &DAG, - bool Commute = false) { - assert((N->getOpcode() == RISCVISD::ADD_VL || - N->getOpcode() == RISCVISD::SUB_VL) && - "Unexpected opcode"); - bool IsAdd = N->getOpcode() == RISCVISD::ADD_VL; - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); - if (Commute) - std::swap(Op0, Op1); +namespace { +// Forward declaration of the structure holding the necessary information to +// apply a combine. +struct CombineResult; - MVT VT = N->getSimpleValueType(0); +/// Helper class for folding sign/zero extensions. +/// In particular, this class is used for the following combines: +/// add_vl -> vwadd(u) | vwadd(u)_w +/// sub_vl -> vwsub(u) | vwsub(u)_w +/// mul_vl -> vwmul(u) | vwmul_su +/// +/// An object of this class represents an operand of the operation we want to +/// combine. +/// E.g., when trying to combine `mul_vl a, b`, we will have one instance of +/// NodeExtensionHelper for `a` and one for `b`. +/// +/// This class abstracts away how the extension is materialized and +/// how its Mask, VL, number of users affect the combines. +/// +/// In particular: +/// - VWADD_W is conceptually == add(op0, sext(op1)) +/// - VWADDU_W == add(op0, zext(op1)) +/// - VWSUB_W == sub(op0, sext(op1)) +/// - VWSUBU_W == sub(op0, zext(op1)) +/// +/// And VMV_V_X_VL, depending on the value, is conceptually equivalent to +/// zext|sext(smaller_value). +struct NodeExtensionHelper { + /// Records if this operand is like being zero extended. + bool SupportsZExt; + /// Records if this operand is like being sign extended. + /// Note: SupportsZExt and SupportsSExt are not mutually exclusive. For + /// instance, a splat constant (e.g., 3), would support being both sign and + /// zero extended. + bool SupportsSExt; + /// This boolean captures whether we care if this operand would still be + /// around after the folding happens. + bool EnforceOneUse; + /// Records if this operand's mask needs to match the mask of the operation + /// that it will fold into. + bool CheckMask; + /// Value of the Mask for this operand. + /// It may be SDValue(). + SDValue Mask; + /// Value of the vector length operand. + /// It may be SDValue(). + SDValue VL; + /// Original value that this NodeExtensionHelper represents. + SDValue OrigOperand; + + /// Get the value feeding the extension or the value itself. + /// E.g., for zext(a), this would return a. + SDValue getSource() const { + switch (OrigOperand.getOpcode()) { + case RISCVISD::VSEXT_VL: + case RISCVISD::VZEXT_VL: + return OrigOperand.getOperand(0); + default: + return OrigOperand; + } + } + + /// Check if this instance represents a splat. + bool isSplat() const { + return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL; + } + + /// Get or create a value that can feed \p Root with the given \p ExtOpc. + /// If \p ExtOpc is None, this returns the source of this operand. + /// \see ::getSource(). + SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG, + Optional ExtOpc) const { + SDValue Source = getSource(); + if (!ExtOpc) + return Source; + + MVT NarrowVT = getNarrowType(Root); + // If we need an extension, we should be changing the type. + assert(Source.getValueType() != NarrowVT && "Needless extension"); + SDLoc DL(Root); + auto [Mask, VL] = getMaskAndVL(Root); + switch (OrigOperand.getOpcode()) { + case RISCVISD::VSEXT_VL: + case RISCVISD::VZEXT_VL: + return DAG.getNode(*ExtOpc, DL, NarrowVT, Source, Mask, VL); + case RISCVISD::VMV_V_X_VL: + return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT, + DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL); + default: + // Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL + // and that operand should already have the right NarrowVT so no + // extension should be required at this point. + llvm_unreachable("Unsupported opcode"); + } + } - // Determine the narrow size for a widening add/sub. - unsigned NarrowSize = VT.getScalarSizeInBits() / 2; - MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize), - VT.getVectorElementCount()); + /// Helper function to get the narrow type for \p Root. + /// The narrow type is the type of \p Root where we divided the size of each + /// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>. + /// \pre The size of the type of the elements of Root must be a multiple of 2 + /// and be greater than 16. + static MVT getNarrowType(const SDNode *Root) { + MVT VT = Root->getSimpleValueType(0); - SDValue Merge = N->getOperand(2); - SDValue Mask = N->getOperand(3); - SDValue VL = N->getOperand(4); + // Determine the narrow size. + unsigned NarrowSize = VT.getScalarSizeInBits() / 2; + assert(NarrowSize >= 8 && "Trying to extend something we can't represent"); + MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize), + VT.getVectorElementCount()); + return NarrowVT; + } - SDLoc DL(N); + /// Return the opcode required to materialize the folding of the sign + /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for + /// both operands for \p Opcode. + /// Put differently, get the opcode to materialize: + /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b) + /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b) + /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()). + static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) { + switch (Opcode) { + case RISCVISD::ADD_VL: + case RISCVISD::VWADD_W_VL: + case RISCVISD::VWADDU_W_VL: + return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL; + case RISCVISD::MUL_VL: + return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL; + case RISCVISD::SUB_VL: + case RISCVISD::VWSUB_W_VL: + case RISCVISD::VWSUBU_W_VL: + return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL; + default: + llvm_unreachable("Unexpected opcode"); + } + } - // If the RHS is a sext or zext, we can form a widening op. - if ((Op1.getOpcode() == RISCVISD::VZEXT_VL || - Op1.getOpcode() == RISCVISD::VSEXT_VL) && - Op1.hasOneUse() && Op1.getOperand(1) == Mask && Op1.getOperand(2) == VL) { - unsigned ExtOpc = Op1.getOpcode(); - Op1 = Op1.getOperand(0); - // Re-introduce narrower extends if needed. - if (Op1.getValueType() != NarrowVT) - Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL); - - unsigned WOpc; - if (ExtOpc == RISCVISD::VSEXT_VL) - WOpc = IsAdd ? RISCVISD::VWADD_W_VL : RISCVISD::VWSUB_W_VL; - else - WOpc = IsAdd ? RISCVISD::VWADDU_W_VL : RISCVISD::VWSUBU_W_VL; + /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) -> + /// newOpcode(a, b). + static unsigned getSUOpcode(unsigned Opcode) { + assert(Opcode == RISCVISD::MUL_VL && "SU is only supported for MUL"); + return RISCVISD::VWMULSU_VL; + } - return DAG.getNode(WOpc, DL, VT, Op0, Op1, Merge, Mask, VL); + /// Get the opcode to materialize \p Opcode(a, s|zext(b)) -> + /// newOpcode(a, b). + static unsigned getWOpcode(unsigned Opcode, bool IsSExt) { + switch (Opcode) { + case RISCVISD::ADD_VL: + return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL; + case RISCVISD::SUB_VL: + return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL; + default: + llvm_unreachable("Unexpected opcode"); + } } - // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar - // sext/zext? + using CombineToTry = std::function( + SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/, + const NodeExtensionHelper & /*RHS*/)>; - return SDValue(); -} + /// Check if this node needs to be fully folded or extended for all users. + bool needToPromoteOtherUsers() const { return EnforceOneUse; } -// Try to convert vwadd(u).wv/wx or vwsub(u).wv/wx to vwadd(u).vv/vx or -// vwsub(u).vv/vx. -static SDValue combineVWADD_W_VL_VWSUB_W_VL(SDNode *N, SelectionDAG &DAG) { - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); - SDValue Merge = N->getOperand(2); - SDValue Mask = N->getOperand(3); - SDValue VL = N->getOperand(4); + /// Helper method to set the various fields of this struct based on the + /// type of \p Root. + void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG) { + SupportsZExt = false; + SupportsSExt = false; + EnforceOneUse = true; + CheckMask = true; + switch (OrigOperand.getOpcode()) { + case RISCVISD::VZEXT_VL: + SupportsZExt = true; + Mask = OrigOperand.getOperand(1); + VL = OrigOperand.getOperand(2); + break; + case RISCVISD::VSEXT_VL: + SupportsSExt = true; + Mask = OrigOperand.getOperand(1); + VL = OrigOperand.getOperand(2); + break; + case RISCVISD::VMV_V_X_VL: { + // Historically, we didn't care about splat values not disappearing during + // combines. + EnforceOneUse = false; + CheckMask = false; + VL = OrigOperand.getOperand(2); - MVT VT = N->getSimpleValueType(0); - MVT NarrowVT = Op1.getSimpleValueType(); - unsigned NarrowSize = NarrowVT.getScalarSizeInBits(); + // The operand is a splat of a scalar. - unsigned VOpc; - switch (N->getOpcode()) { - default: llvm_unreachable("Unexpected opcode"); - case RISCVISD::VWADD_W_VL: VOpc = RISCVISD::VWADD_VL; break; - case RISCVISD::VWSUB_W_VL: VOpc = RISCVISD::VWSUB_VL; break; - case RISCVISD::VWADDU_W_VL: VOpc = RISCVISD::VWADDU_VL; break; - case RISCVISD::VWSUBU_W_VL: VOpc = RISCVISD::VWSUBU_VL; break; - } + // The pasthru must be undef for tail agnostic. + if (!OrigOperand.getOperand(0).isUndef()) + break; - bool IsSigned = N->getOpcode() == RISCVISD::VWADD_W_VL || - N->getOpcode() == RISCVISD::VWSUB_W_VL; + // Get the scalar value. + SDValue Op = OrigOperand.getOperand(1); + + // See if we have enough sign bits or zero bits in the scalar to use a + // widening opcode by splatting to smaller element size. + MVT VT = Root->getSimpleValueType(0); + unsigned EltBits = VT.getScalarSizeInBits(); + unsigned ScalarBits = Op.getValueSizeInBits(); + // Make sure we're getting all element bits from the scalar register. + // FIXME: Support implicit sign extension of vmv.v.x? + if (ScalarBits < EltBits) + break; - SDLoc DL(N); + unsigned NarrowSize = VT.getScalarSizeInBits() / 2; + // If the narrow type cannot be expressed with a legal VMV, + // this is not a valid candidate. + if (NarrowSize < 8) + break; - // If the LHS is a sext or zext, we can narrow this op to the same size as - // the RHS. - if (((Op0.getOpcode() == RISCVISD::VZEXT_VL && !IsSigned) || - (Op0.getOpcode() == RISCVISD::VSEXT_VL && IsSigned)) && - Op0.hasOneUse() && Op0.getOperand(1) == Mask && Op0.getOperand(2) == VL) { - unsigned ExtOpc = Op0.getOpcode(); - Op0 = Op0.getOperand(0); - // Re-introduce narrower extends if needed. - if (Op0.getValueType() != NarrowVT) - Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL); - return DAG.getNode(VOpc, DL, VT, Op0, Op1, Merge, Mask, VL); - } - - bool IsAdd = N->getOpcode() == RISCVISD::VWADD_W_VL || - N->getOpcode() == RISCVISD::VWADDU_W_VL; - - // Look for splats on the left hand side of a vwadd(u).wv. We might be able - // to commute and use a vwadd(u).vx instead. - if (IsAdd && Op0.getOpcode() == RISCVISD::VMV_V_X_VL && - Op0.getOperand(0).isUndef() && Op0.getOperand(2) == VL) { - Op0 = Op0.getOperand(1); - - // See if have enough sign bits or zero bits in the scalar to use a - // widening add/sub by splatting to smaller element size. - unsigned EltBits = VT.getScalarSizeInBits(); - unsigned ScalarBits = Op0.getValueSizeInBits(); - // Make sure we're getting all element bits from the scalar register. - // FIXME: Support implicit sign extension of vmv.v.x? - if (ScalarBits < EltBits) - return SDValue(); + if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize) + SupportsSExt = true; + if (DAG.MaskedValueIsZero(Op, + APInt::getBitsSetFrom(ScalarBits, NarrowSize))) + SupportsZExt = true; + break; + } + default: + break; + } + } - if (IsSigned) { - if (DAG.ComputeMaxSignificantBits(Op0) > NarrowSize) - return SDValue(); - } else { - if (!DAG.MaskedValueIsZero(Op0, - APInt::getBitsSetFrom(ScalarBits, NarrowSize))) - return SDValue(); + /// Check if \p Root supports any extension folding combines. + static bool isSupportedRoot(const SDNode *Root) { + switch (Root->getOpcode()) { + case RISCVISD::ADD_VL: + case RISCVISD::MUL_VL: + case RISCVISD::VWADD_W_VL: + case RISCVISD::VWADDU_W_VL: + case RISCVISD::SUB_VL: + case RISCVISD::VWSUB_W_VL: + case RISCVISD::VWSUBU_W_VL: + return true; + default: + return false; } + } - Op0 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT, - DAG.getUNDEF(NarrowVT), Op0, VL); - return DAG.getNode(VOpc, DL, VT, Op1, Op0, Merge, Mask, VL); + /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx). + NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG) { + assert(isSupportedRoot(Root) && "Trying to build an helper with an " + "unsupported root"); + assert(OperandIdx < 2 && "Requesting something else than LHS or RHS"); + OrigOperand = Root->getOperand(OperandIdx); + + unsigned Opc = Root->getOpcode(); + switch (Opc) { + // We consider VW(U)_W(LHS, RHS) as if they were + // (LHS, S|ZEXT(RHS)) + case RISCVISD::VWADD_W_VL: + case RISCVISD::VWADDU_W_VL: + case RISCVISD::VWSUB_W_VL: + case RISCVISD::VWSUBU_W_VL: + if (OperandIdx == 1) { + SupportsZExt = + Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL; + SupportsSExt = !SupportsZExt; + std::tie(Mask, VL) = getMaskAndVL(Root); + // There's no existing extension here, so we don't have to worry about + // making sure it gets removed. + EnforceOneUse = false; + break; + } + [[fallthrough]]; + default: + fillUpExtensionSupport(Root, DAG); + break; + } } - return SDValue(); -} + /// Check if this operand is compatible with the given vector length \p VL. + bool isVLCompatible(SDValue VL) const { return this->VL && this->VL == VL; } -// Try to form VWMUL, VWMULU or VWMULSU. -// TODO: Support VWMULSU.vx with a sign extend Op and a splat of scalar Op. -static SDValue combineMUL_VLToVWMUL_VL(SDNode *N, SelectionDAG &DAG, - bool Commute) { - assert(N->getOpcode() == RISCVISD::MUL_VL && "Unexpected opcode"); - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); - if (Commute) - std::swap(Op0, Op1); - - bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL; - bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL; - bool IsVWMULSU = IsSignExt && Op1.getOpcode() == RISCVISD::VZEXT_VL; - if ((!IsSignExt && !IsZeroExt) || !Op0.hasOneUse()) - return SDValue(); + /// Check if this operand is compatible with the given \p Mask. + bool isMaskCompatible(SDValue Mask) const { + return !CheckMask || (this->Mask && this->Mask == Mask); + } - SDValue Merge = N->getOperand(2); - SDValue Mask = N->getOperand(3); - SDValue VL = N->getOperand(4); + /// Helper function to get the Mask and VL from \p Root. + static std::pair getMaskAndVL(const SDNode *Root) { + assert(isSupportedRoot(Root) && "Unexpected root"); + return std::make_pair(Root->getOperand(3), Root->getOperand(4)); + } - // Make sure the mask and VL match. - if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL) - return SDValue(); + /// Check if the Mask and VL of this operand are compatible with \p Root. + bool areVLAndMaskCompatible(const SDNode *Root) const { + auto [Mask, VL] = getMaskAndVL(Root); + return isMaskCompatible(Mask) && isVLCompatible(VL); + } - MVT VT = N->getSimpleValueType(0); + /// Helper function to check if \p N is commutative with respect to the + /// foldings that are supported by this class. + static bool isCommutative(const SDNode *N) { + switch (N->getOpcode()) { + case RISCVISD::ADD_VL: + case RISCVISD::MUL_VL: + case RISCVISD::VWADD_W_VL: + case RISCVISD::VWADDU_W_VL: + return true; + case RISCVISD::SUB_VL: + case RISCVISD::VWSUB_W_VL: + case RISCVISD::VWSUBU_W_VL: + return false; + default: + llvm_unreachable("Unexpected opcode"); + } + } - // Determine the narrow size for a widening multiply. - unsigned NarrowSize = VT.getScalarSizeInBits() / 2; - MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize), - VT.getVectorElementCount()); + /// Get a list of combine to try for folding extensions in \p Root. + /// Note that each returned CombineToTry function doesn't actually modify + /// anything. Instead they produce an optional CombineResult that if not None, + /// need to be materialized for the combine to be applied. + /// \see CombineResult::materialize. + /// If the related CombineToTry function returns None, that means the combine + /// didn't match. + static SmallVector getSupportedFoldings(const SDNode *Root); +}; - SDLoc DL(N); +/// Helper structure that holds all the necessary information to materialize a +/// combine that does some extension folding. +struct CombineResult { + /// Opcode to be generated when materializing the combine. + unsigned TargetOpcode; + /// Extension opcode to be applied to the source of LHS when materializing + /// TargetOpcode. + /// \see NodeExtensionHelper::getSource(). + Optional LHSExtOpc; + /// Extension opcode to be applied to the source of RHS when materializing + /// TargetOpcode. + Optional RHSExtOpc; + /// Root of the combine. + SDNode *Root; + /// LHS of the TargetOpcode. + const NodeExtensionHelper &LHS; + /// RHS of the TargetOpcode. + const NodeExtensionHelper &RHS; + + CombineResult(unsigned TargetOpcode, SDNode *Root, + const NodeExtensionHelper &LHS, Optional SExtLHS, + const NodeExtensionHelper &RHS, Optional SExtRHS) + : TargetOpcode(TargetOpcode), Root(Root), LHS(LHS), RHS(RHS) { + MVT NarrowVT = NodeExtensionHelper::getNarrowType(Root); + if (SExtLHS && LHS.getSource().getValueType() != NarrowVT) + LHSExtOpc = *SExtLHS ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL; + if (SExtRHS && RHS.getSource().getValueType() != NarrowVT) + RHSExtOpc = *SExtRHS ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL; + } + + /// Return a value that uses TargetOpcode and that can be used to replace + /// Root. + /// The actual replacement is *not* done in that method. + SDValue materialize(SelectionDAG &DAG) const { + SDValue Mask, VL, Merge; + std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root); + Merge = Root->getOperand(2); + return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0), + LHS.getOrCreateExtendedOp(Root, DAG, LHSExtOpc), + RHS.getOrCreateExtendedOp(Root, DAG, RHSExtOpc), Merge, + Mask, VL); + } +}; - // See if the other operand is the same opcode. - if (IsVWMULSU || Op0.getOpcode() == Op1.getOpcode()) { - if (!Op1.hasOneUse()) - return SDValue(); +/// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS)) +/// where `ext` is the same for both LHS and RHS (i.e., both are sext or both +/// are zext) and LHS and RHS can be folded into Root. +/// AllowSExt and AllozZExt define which form `ext` can take in this pattern. +/// +/// \note If the pattern can match with both zext and sext, the returned +/// CombineResult will feature the zext result. +/// +/// \returns None if the pattern doesn't match or a CombineResult that can be +/// used to apply the pattern. +static Optional +canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS, bool AllowSExt, + bool AllowZExt) { + assert((AllowSExt || AllowZExt) && "Forgot to set what you want?"); + if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root)) + return None; + if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt) + return CombineResult(NodeExtensionHelper::getSameExtensionOpcode( + Root->getOpcode(), /*IsSExt=*/false), + Root, LHS, /*SExtLHS=*/false, RHS, + /*SExtRHS=*/false); + if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt) + return CombineResult(NodeExtensionHelper::getSameExtensionOpcode( + Root->getOpcode(), /*IsSExt=*/true), + Root, LHS, /*SExtLHS=*/true, RHS, + /*SExtRHS=*/true); + return None; +} - // Make sure the mask and VL match. - if (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL) - return SDValue(); +/// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS)) +/// where `ext` is the same for both LHS and RHS (i.e., both are sext or both +/// are zext) and LHS and RHS can be folded into Root. +/// +/// \returns None if the pattern doesn't match or a CombineResult that can be +/// used to apply the pattern. +static Optional +canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS) { + return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true, + /*AllowZExt=*/true); +} - Op1 = Op1.getOperand(0); - } else if (Op1.getOpcode() == RISCVISD::VMV_V_X_VL) { - // The operand is a splat of a scalar. +/// Check if \p Root follows a pattern Root(LHS, ext(RHS)) +/// +/// \returns None if the pattern doesn't match or a CombineResult that can be +/// used to apply the pattern. +static Optional canFoldToVW_W(SDNode *Root, + const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS) { + if (!RHS.areVLAndMaskCompatible(Root)) + return None; - // The pasthru must be undef for tail agnostic - if (!Op1.getOperand(0).isUndef()) - return SDValue(); - // The VL must be the same. - if (Op1.getOperand(2) != VL) - return SDValue(); + // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar + // sext/zext? + // Control this behavior behind an option (AllowSplatInVW_W) for testing + // purposes. + if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W)) + return CombineResult( + NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false), + Root, LHS, /*SExtLHS=*/None, RHS, /*SExtRHS=*/false); + if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W)) + return CombineResult( + NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true), + Root, LHS, /*SExtLHS=*/None, RHS, /*SExtRHS=*/true); + return None; +} - // Get the scalar value. - Op1 = Op1.getOperand(1); +/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS)) +/// +/// \returns None if the pattern doesn't match or a CombineResult that can be +/// used to apply the pattern. +static Optional +canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS) { + return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true, + /*AllowZExt=*/false); +} - // See if have enough sign bits or zero bits in the scalar to use a - // widening multiply by splatting to smaller element size. - unsigned EltBits = VT.getScalarSizeInBits(); - unsigned ScalarBits = Op1.getValueSizeInBits(); - // Make sure we're getting all element bits from the scalar register. - // FIXME: Support implicit sign extension of vmv.v.x? - if (ScalarBits < EltBits) - return SDValue(); +/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS)) +/// +/// \returns None if the pattern doesn't match or a CombineResult that can be +/// used to apply the pattern. +static Optional +canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS) { + return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false, + /*AllowZExt=*/true); +} - // If the LHS is a sign extend, try to use vwmul. - if (IsSignExt && DAG.ComputeMaxSignificantBits(Op1) <= NarrowSize) { - // Can use vwmul. - } else if (DAG.MaskedValueIsZero( - Op1, APInt::getBitsSetFrom(ScalarBits, NarrowSize))) { - // Scalar is zero extended, if the vector is sign extended we can use - // vwmulsu. If the vector is zero extended we can use vwmulu. - IsVWMULSU = IsSignExt; - } else - return SDValue(); +/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS)) +/// +/// \returns None if the pattern doesn't match or a CombineResult that can be +/// used to apply the pattern. +static Optional canFoldToVW_SU(SDNode *Root, + const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS) { + if (!LHS.SupportsSExt || !RHS.SupportsZExt) + return None; + if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root)) + return None; + return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()), + Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false); +} - Op1 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT, - DAG.getUNDEF(NarrowVT), Op1, VL); - } else +SmallVector +NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { + SmallVector Strategies; + switch (Root->getOpcode()) { + case RISCVISD::ADD_VL: + case RISCVISD::SUB_VL: + // add|sub -> vwadd(u)|vwsub(u) + Strategies.push_back(canFoldToVWWithSameExtension); + // add|sub -> vwadd(u)_w|vwsub(u)_w + Strategies.push_back(canFoldToVW_W); + break; + case RISCVISD::MUL_VL: + // mul -> vwmul(u) + Strategies.push_back(canFoldToVWWithSameExtension); + // mul -> vwmulsu + Strategies.push_back(canFoldToVW_SU); + break; + case RISCVISD::VWADD_W_VL: + case RISCVISD::VWSUB_W_VL: + // vwadd_w|vwsub_w -> vwadd|vwsub + Strategies.push_back(canFoldToVWWithSEXT); + break; + case RISCVISD::VWADDU_W_VL: + case RISCVISD::VWSUBU_W_VL: + // vwaddu_w|vwsubu_w -> vwaddu|vwsubu + Strategies.push_back(canFoldToVWWithZEXT); + break; + default: + llvm_unreachable("Unexpected opcode"); + } + return Strategies; +} +} // End anonymous namespace. + +/// Combine a binary operation to its equivalent VW or VW_W form. +/// The supported combines are: +/// add_vl -> vwadd(u) | vwadd(u)_w +/// sub_vl -> vwsub(u) | vwsub(u)_w +/// mul_vl -> vwmul(u) | vwmul_su +/// vwadd_w(u) -> vwadd(u) +/// vwub_w(u) -> vwadd(u) +static SDValue +combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { + SelectionDAG &DAG = DCI.DAG; + + assert(NodeExtensionHelper::isSupportedRoot(N) && + "Shouldn't have called this method"); + + NodeExtensionHelper LHS(N, 0, DAG); + NodeExtensionHelper RHS(N, 1, DAG); + + if (LHS.needToPromoteOtherUsers() && !LHS.OrigOperand.hasOneUse()) return SDValue(); - Op0 = Op0.getOperand(0); + if (RHS.needToPromoteOtherUsers() && !RHS.OrigOperand.hasOneUse()) + return SDValue(); - // Re-introduce narrower extends if needed. - unsigned ExtOpc = IsSignExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL; - if (Op0.getValueType() != NarrowVT) - Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL); - // vwmulsu requires second operand to be zero extended. - ExtOpc = IsVWMULSU ? RISCVISD::VZEXT_VL : ExtOpc; - if (Op1.getValueType() != NarrowVT) - Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL); + SmallVector FoldingStrategies = + NodeExtensionHelper::getSupportedFoldings(N); - unsigned WMulOpc = RISCVISD::VWMULSU_VL; - if (!IsVWMULSU) - WMulOpc = IsSignExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL; - return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Merge, Mask, VL); + assert(!FoldingStrategies.empty() && "Nothing to be folded"); + for (int Attempt = 0; Attempt != 1 + NodeExtensionHelper::isCommutative(N); + ++Attempt) { + for (NodeExtensionHelper::CombineToTry FoldingStrategy : + FoldingStrategies) { + Optional Res = FoldingStrategy(N, LHS, RHS); + if (Res) + return Res->materialize(DAG); + } + std::swap(LHS, RHS); + } + return SDValue(); } // Fold @@ -9232,21 +9559,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, break; } case RISCVISD::ADD_VL: - if (SDValue V = combineADDSUB_VLToVWADDSUB_VL(N, DAG, /*Commute*/ false)) - return V; - return combineADDSUB_VLToVWADDSUB_VL(N, DAG, /*Commute*/ true); case RISCVISD::SUB_VL: - return combineADDSUB_VLToVWADDSUB_VL(N, DAG); case RISCVISD::VWADD_W_VL: case RISCVISD::VWADDU_W_VL: case RISCVISD::VWSUB_W_VL: case RISCVISD::VWSUBU_W_VL: - return combineVWADD_W_VL_VWSUB_W_VL(N, DAG); case RISCVISD::MUL_VL: - if (SDValue V = combineMUL_VLToVWMUL_VL(N, DAG, /*Commute*/ false)) - return V; - // Mul is commutative. - return combineMUL_VLToVWMUL_VL(N, DAG, /*Commute*/ true); + return combineBinOp_VLToVWBinOp_VL(N, DCI); case RISCVISD::VFMADD_VL: case RISCVISD::VFNMADD_VL: case RISCVISD::VFMSUB_VL: diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll index 5335fbe..8862e33 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll @@ -18,6 +18,32 @@ define <2 x i16> @vwmul_v2i16(<2 x i8>* %x, <2 x i8>* %y) { ret <2 x i16> %e } +define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) { +; CHECK-LABEL: vwmul_v2i16_multiple_users: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, mu +; CHECK-NEXT: vle8.v v8, (a0) +; CHECK-NEXT: vle8.v v9, (a1) +; CHECK-NEXT: vle8.v v10, (a2) +; CHECK-NEXT: vsext.vf2 v11, v8 +; CHECK-NEXT: vsext.vf2 v8, v9 +; CHECK-NEXT: vsext.vf2 v9, v10 +; CHECK-NEXT: vmul.vv v8, v11, v8 +; CHECK-NEXT: vmul.vv v9, v11, v9 +; CHECK-NEXT: vor.vv v8, v8, v9 +; CHECK-NEXT: ret + %a = load <2 x i8>, <2 x i8>* %x + %b = load <2 x i8>, <2 x i8>* %y + %b2 = load <2 x i8>, <2 x i8>* %z + %c = sext <2 x i8> %a to <2 x i16> + %d = sext <2 x i8> %b to <2 x i16> + %d2 = sext <2 x i8> %b2 to <2 x i16> + %e = mul <2 x i16> %c, %d + %f = mul <2 x i16> %c, %d2 + %g = or <2 x i16> %e, %f + ret <2 x i16> %g +} + define <4 x i16> @vwmul_v4i16(<4 x i8>* %x, <4 x i8>* %y) { ; CHECK-LABEL: vwmul_v4i16: ; CHECK: # %bb.0: diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll index 2746c8e..82a61e4 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll @@ -701,11 +701,10 @@ define <8 x i16> @vwmulsu_vx_v8i16_i8(<8 x i8>* %x, i8* %y) { define <8 x i16> @vwmulsu_vx_v8i16_i8_swap(<8 x i8>* %x, i8* %y) { ; CHECK-LABEL: vwmulsu_vx_v8i16_i8_swap: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, mu -; CHECK-NEXT: vle8.v v8, (a0) -; CHECK-NEXT: lb a0, 0(a1) -; CHECK-NEXT: vzext.vf2 v9, v8 -; CHECK-NEXT: vmul.vx v8, v9, a0 +; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, mu +; CHECK-NEXT: vle8.v v9, (a0) +; CHECK-NEXT: vlse8.v v10, (a1), zero +; CHECK-NEXT: vwmulsu.vv v8, v10, v9 ; CHECK-NEXT: ret %a = load <8 x i8>, <8 x i8>* %x %b = load i8, i8* %y diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsub.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsub.ll index 4704a32..017a168 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsub.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsub.ll @@ -647,12 +647,10 @@ define <16 x i64> @vwsub_vx_v16i64(<16 x i32>* %x, i32 %y) { define <8 x i16> @vwsub_vx_v8i16_i8(<8 x i8>* %x, i8* %y) { ; CHECK-LABEL: vwsub_vx_v8i16_i8: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, mu -; CHECK-NEXT: lb a1, 0(a1) +; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, mu ; CHECK-NEXT: vle8.v v9, (a0) -; CHECK-NEXT: vmv.v.x v8, a1 -; CHECK-NEXT: vsetvli zero, zero, e8, mf2, ta, mu -; CHECK-NEXT: vwsub.wv v8, v8, v9 +; CHECK-NEXT: vlse8.v v10, (a1), zero +; CHECK-NEXT: vwsub.vv v8, v10, v9 ; CHECK-NEXT: ret %a = load <8 x i8>, <8 x i8>* %x %b = load i8, i8* %y @@ -684,12 +682,11 @@ define <8 x i16> @vwsub_vx_v8i16_i16(<8 x i8>* %x, i16* %y) { define <4 x i32> @vwsub_vx_v4i32_i8(<4 x i16>* %x, i8* %y) { ; CHECK-LABEL: vwsub_vx_v4i32_i8: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, mu +; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, mu ; CHECK-NEXT: lb a1, 0(a1) ; CHECK-NEXT: vle16.v v9, (a0) -; CHECK-NEXT: vmv.v.x v8, a1 -; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, mu -; CHECK-NEXT: vwsub.wv v8, v8, v9 +; CHECK-NEXT: vmv.v.x v10, a1 +; CHECK-NEXT: vwsub.vv v8, v10, v9 ; CHECK-NEXT: ret %a = load <4 x i16>, <4 x i16>* %x %b = load i8, i8* %y @@ -704,12 +701,10 @@ define <4 x i32> @vwsub_vx_v4i32_i8(<4 x i16>* %x, i8* %y) { define <4 x i32> @vwsub_vx_v4i32_i16(<4 x i16>* %x, i16* %y) { ; CHECK-LABEL: vwsub_vx_v4i32_i16: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, mu -; CHECK-NEXT: lh a1, 0(a1) +; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, mu ; CHECK-NEXT: vle16.v v9, (a0) -; CHECK-NEXT: vmv.v.x v8, a1 -; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, mu -; CHECK-NEXT: vwsub.wv v8, v8, v9 +; CHECK-NEXT: vlse16.v v10, (a1), zero +; CHECK-NEXT: vwsub.vv v8, v10, v9 ; CHECK-NEXT: ret %a = load <4 x i16>, <4 x i16>* %x %b = load i16, i16* %y @@ -756,12 +751,11 @@ define <2 x i64> @vwsub_vx_v2i64_i8(<2 x i32>* %x, i8* %y) nounwind { ; ; RV64-LABEL: vwsub_vx_v2i64_i8: ; RV64: # %bb.0: -; RV64-NEXT: vsetivli zero, 2, e64, m1, ta, mu +; RV64-NEXT: vsetivli zero, 2, e32, mf2, ta, mu ; RV64-NEXT: lb a1, 0(a1) ; RV64-NEXT: vle32.v v9, (a0) -; RV64-NEXT: vmv.v.x v8, a1 -; RV64-NEXT: vsetvli zero, zero, e32, mf2, ta, mu -; RV64-NEXT: vwsub.wv v8, v8, v9 +; RV64-NEXT: vmv.v.x v10, a1 +; RV64-NEXT: vwsub.vv v8, v10, v9 ; RV64-NEXT: ret %a = load <2 x i32>, <2 x i32>* %x %b = load i8, i8* %y @@ -791,12 +785,11 @@ define <2 x i64> @vwsub_vx_v2i64_i16(<2 x i32>* %x, i16* %y) nounwind { ; ; RV64-LABEL: vwsub_vx_v2i64_i16: ; RV64: # %bb.0: -; RV64-NEXT: vsetivli zero, 2, e64, m1, ta, mu +; RV64-NEXT: vsetivli zero, 2, e32, mf2, ta, mu ; RV64-NEXT: lh a1, 0(a1) ; RV64-NEXT: vle32.v v9, (a0) -; RV64-NEXT: vmv.v.x v8, a1 -; RV64-NEXT: vsetvli zero, zero, e32, mf2, ta, mu -; RV64-NEXT: vwsub.wv v8, v8, v9 +; RV64-NEXT: vmv.v.x v10, a1 +; RV64-NEXT: vwsub.vv v8, v10, v9 ; RV64-NEXT: ret %a = load <2 x i32>, <2 x i32>* %x %b = load i16, i16* %y @@ -826,12 +819,10 @@ define <2 x i64> @vwsub_vx_v2i64_i32(<2 x i32>* %x, i32* %y) nounwind { ; ; RV64-LABEL: vwsub_vx_v2i64_i32: ; RV64: # %bb.0: -; RV64-NEXT: vsetivli zero, 2, e64, m1, ta, mu -; RV64-NEXT: lw a1, 0(a1) +; RV64-NEXT: vsetivli zero, 2, e32, mf2, ta, mu ; RV64-NEXT: vle32.v v9, (a0) -; RV64-NEXT: vmv.v.x v8, a1 -; RV64-NEXT: vsetvli zero, zero, e32, mf2, ta, mu -; RV64-NEXT: vwsub.wv v8, v8, v9 +; RV64-NEXT: vlse32.v v10, (a1), zero +; RV64-NEXT: vwsub.vv v8, v10, v9 ; RV64-NEXT: ret %a = load <2 x i32>, <2 x i32>* %x %b = load i32, i32* %y diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll index 258fbbe..e443c9b 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll @@ -647,12 +647,10 @@ define <16 x i64> @vwsubu_vx_v16i64(<16 x i32>* %x, i32 %y) { define <8 x i16> @vwsubu_vx_v8i16_i8(<8 x i8>* %x, i8* %y) { ; CHECK-LABEL: vwsubu_vx_v8i16_i8: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, mu -; CHECK-NEXT: lbu a1, 0(a1) +; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, mu ; CHECK-NEXT: vle8.v v9, (a0) -; CHECK-NEXT: vmv.v.x v8, a1 -; CHECK-NEXT: vsetvli zero, zero, e8, mf2, ta, mu -; CHECK-NEXT: vwsubu.wv v8, v8, v9 +; CHECK-NEXT: vlse8.v v10, (a1), zero +; CHECK-NEXT: vwsubu.vv v8, v10, v9 ; CHECK-NEXT: ret %a = load <8 x i8>, <8 x i8>* %x %b = load i8, i8* %y @@ -684,12 +682,11 @@ define <8 x i16> @vwsubu_vx_v8i16_i16(<8 x i8>* %x, i16* %y) { define <4 x i32> @vwsubu_vx_v4i32_i8(<4 x i16>* %x, i8* %y) { ; CHECK-LABEL: vwsubu_vx_v4i32_i8: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, mu +; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, mu ; CHECK-NEXT: lbu a1, 0(a1) ; CHECK-NEXT: vle16.v v9, (a0) -; CHECK-NEXT: vmv.v.x v8, a1 -; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, mu -; CHECK-NEXT: vwsubu.wv v8, v8, v9 +; CHECK-NEXT: vmv.v.x v10, a1 +; CHECK-NEXT: vwsubu.vv v8, v10, v9 ; CHECK-NEXT: ret %a = load <4 x i16>, <4 x i16>* %x %b = load i8, i8* %y @@ -704,12 +701,10 @@ define <4 x i32> @vwsubu_vx_v4i32_i8(<4 x i16>* %x, i8* %y) { define <4 x i32> @vwsubu_vx_v4i32_i16(<4 x i16>* %x, i16* %y) { ; CHECK-LABEL: vwsubu_vx_v4i32_i16: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, mu -; CHECK-NEXT: lhu a1, 0(a1) +; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, mu ; CHECK-NEXT: vle16.v v9, (a0) -; CHECK-NEXT: vmv.v.x v8, a1 -; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, mu -; CHECK-NEXT: vwsubu.wv v8, v8, v9 +; CHECK-NEXT: vlse16.v v10, (a1), zero +; CHECK-NEXT: vwsubu.vv v8, v10, v9 ; CHECK-NEXT: ret %a = load <4 x i16>, <4 x i16>* %x %b = load i16, i16* %y @@ -755,12 +750,11 @@ define <2 x i64> @vwsubu_vx_v2i64_i8(<2 x i32>* %x, i8* %y) nounwind { ; ; RV64-LABEL: vwsubu_vx_v2i64_i8: ; RV64: # %bb.0: -; RV64-NEXT: vsetivli zero, 2, e64, m1, ta, mu +; RV64-NEXT: vsetivli zero, 2, e32, mf2, ta, mu ; RV64-NEXT: lbu a1, 0(a1) ; RV64-NEXT: vle32.v v9, (a0) -; RV64-NEXT: vmv.v.x v8, a1 -; RV64-NEXT: vsetvli zero, zero, e32, mf2, ta, mu -; RV64-NEXT: vwsubu.wv v8, v8, v9 +; RV64-NEXT: vmv.v.x v10, a1 +; RV64-NEXT: vwsubu.vv v8, v10, v9 ; RV64-NEXT: ret %a = load <2 x i32>, <2 x i32>* %x %b = load i8, i8* %y @@ -789,12 +783,11 @@ define <2 x i64> @vwsubu_vx_v2i64_i16(<2 x i32>* %x, i16* %y) nounwind { ; ; RV64-LABEL: vwsubu_vx_v2i64_i16: ; RV64: # %bb.0: -; RV64-NEXT: vsetivli zero, 2, e64, m1, ta, mu +; RV64-NEXT: vsetivli zero, 2, e32, mf2, ta, mu ; RV64-NEXT: lhu a1, 0(a1) ; RV64-NEXT: vle32.v v9, (a0) -; RV64-NEXT: vmv.v.x v8, a1 -; RV64-NEXT: vsetvli zero, zero, e32, mf2, ta, mu -; RV64-NEXT: vwsubu.wv v8, v8, v9 +; RV64-NEXT: vmv.v.x v10, a1 +; RV64-NEXT: vwsubu.vv v8, v10, v9 ; RV64-NEXT: ret %a = load <2 x i32>, <2 x i32>* %x %b = load i16, i16* %y @@ -823,12 +816,10 @@ define <2 x i64> @vwsubu_vx_v2i64_i32(<2 x i32>* %x, i32* %y) nounwind { ; ; RV64-LABEL: vwsubu_vx_v2i64_i32: ; RV64: # %bb.0: -; RV64-NEXT: vsetivli zero, 2, e64, m1, ta, mu -; RV64-NEXT: lwu a1, 0(a1) +; RV64-NEXT: vsetivli zero, 2, e32, mf2, ta, mu ; RV64-NEXT: vle32.v v9, (a0) -; RV64-NEXT: vmv.v.x v8, a1 -; RV64-NEXT: vsetvli zero, zero, e32, mf2, ta, mu -; RV64-NEXT: vwsubu.wv v8, v8, v9 +; RV64-NEXT: vlse32.v v10, (a1), zero +; RV64-NEXT: vwsubu.vv v8, v10, v9 ; RV64-NEXT: ret %a = load <2 x i32>, <2 x i32>* %x %b = load i32, i32* %y -- 2.7.4