[RISCV] Fix crash a vector add has a 4x sext and zext operand.
authorCraig Topper <craig.topper@sifive.com>
Mon, 31 Oct 2022 19:37:51 +0000 (12:37 -0700)
committerCraig Topper <craig.topper@sifive.com>
Mon, 31 Oct 2022 22:10:27 +0000 (15:10 -0700)
We can narrow one of the extends and keep the other original by
using a vwaddu.wv or vwadd.wv.

We were previously forgetting to keep the original operand and
instead took the source of its extend. This resulted in a type
mismatch that later failed with an impossible physical register copy.

To fix this I've refactored some code to maintain information about
whether the source needs to be extended at all for longer so we could
use it in materialize.

Differential Revision: https://reviews.llvm.org/D137106

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll

index 5ecf468..9db4a3f 100644 (file)
@@ -8450,24 +8450,29 @@ struct NodeExtensionHelper {
     return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
   }
 
-  /// Get or create a value that can feed \p Root with the given \p ExtOpc.
-  /// If \p ExtOpc is None, this returns the source of this operand.
+  /// Get or create a value that can feed \p Root with the given extension \p
+  /// SExt. If \p SExt is None, this returns the source of this operand.
   /// \see ::getSource().
   SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG,
-                                Optional<unsigned> ExtOpc) const {
+                                Optional<bool> SExt) const {
+    if (!SExt.has_value())
+      return OrigOperand;
+
+    MVT NarrowVT = getNarrowType(Root);
+
     SDValue Source = getSource();
-    if (!ExtOpc)
+    if (Source.getValueType() == NarrowVT)
       return Source;
 
-    MVT NarrowVT = getNarrowType(Root);
+    unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+
     // If we need an extension, we should be changing the type.
-    assert(Source.getValueType() != NarrowVT && "Needless extension");
     SDLoc DL(Root);
     auto [Mask, VL] = getMaskAndVL(Root);
     switch (OrigOperand.getOpcode()) {
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
-      return DAG.getNode(*ExtOpc, DL, NarrowVT, Source, Mask, VL);
+      return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
     case RISCVISD::VMV_V_X_VL:
       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
                          DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
@@ -8712,13 +8717,10 @@ struct NodeExtensionHelper {
 struct CombineResult {
   /// Opcode to be generated when materializing the combine.
   unsigned TargetOpcode;
-  /// Extension opcode to be applied to the source of LHS when materializing
-  /// TargetOpcode.
-  /// \see NodeExtensionHelper::getSource().
-  Optional<unsigned> LHSExtOpc;
-  /// Extension opcode to be applied to the source of RHS when materializing
-  /// TargetOpcode.
-  Optional<unsigned> RHSExtOpc;
+  // No value means no extension is needed. If extension is needed, the value
+  // indicates if it needs to be sign extended.
+  Optional<bool> SExtLHS;
+  Optional<bool> SExtRHS;
   /// Root of the combine.
   SDNode *Root;
   /// LHS of the TargetOpcode.
@@ -8729,13 +8731,8 @@ struct CombineResult {
   CombineResult(unsigned TargetOpcode, SDNode *Root,
                 const NodeExtensionHelper &LHS, Optional<bool> SExtLHS,
                 const NodeExtensionHelper &RHS, Optional<bool> SExtRHS)
-      : TargetOpcode(TargetOpcode), Root(Root), LHS(LHS), RHS(RHS) {
-    MVT NarrowVT = NodeExtensionHelper::getNarrowType(Root);
-    if (SExtLHS && LHS.getSource().getValueType() != NarrowVT)
-      LHSExtOpc = *SExtLHS ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-    if (SExtRHS && RHS.getSource().getValueType() != NarrowVT)
-      RHSExtOpc = *SExtRHS ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-  }
+      : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
+        Root(Root), LHS(LHS), RHS(RHS) {}
 
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
@@ -8745,8 +8742,8 @@ struct CombineResult {
     std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root);
     Merge = Root->getOperand(2);
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, LHSExtOpc),
-                       RHS.getOrCreateExtendedOp(Root, DAG, RHSExtOpc), Merge,
+                       LHS.getOrCreateExtendedOp(Root, DAG, SExtLHS),
+                       RHS.getOrCreateExtendedOp(Root, DAG, SExtRHS), Merge,
                        Mask, VL);
   }
 };
index 7872515..9762738 100644 (file)
@@ -859,3 +859,19 @@ define <2 x i64> @vwaddu_vx_v2i64_i64(<2 x i32>* %x, i64* %y) nounwind {
   %g = add <2 x i64> %e, %f
   ret <2 x i64> %g
 }
+
+define <4 x i64> @crash(<4 x i16> %x, <4 x i16> %y) {
+; CHECK-LABEL: crash:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vsext.vf4 v10, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v9
+; CHECK-NEXT:    vwaddu.wv v10, v10, v8
+; CHECK-NEXT:    vmv2r.v v8, v10
+; CHECK-NEXT:    ret
+  %a = sext <4 x i16> %x to <4 x i64>
+  %b = zext <4 x i16> %y to <4 x i64>
+  %c = add <4 x i64> %a, %b
+  ret <4 x i64> %c
+}