[RISCV] Fix crash a vector add has a 4x sext and zext operand.
authorCraig Topper <craig.topper@sifive.com>
Mon, 31 Oct 2022 19:37:51 +0000 (12:37 -0700)
committerCraig Topper <craig.topper@sifive.com>
Mon, 31 Oct 2022 22:10:27 +0000 (15:10 -0700)
We can narrow one of the extends and keep the other original by
using a vwaddu.wv or vwadd.wv.

We were previously forgetting to keep the original operand and
instead took the source of its extend. This resulted in a type
mismatch that later failed with an impossible physical register copy.

To fix this I've refactored some code to maintain information about
whether the source needs to be extended at all for longer so we could
use it in materialize.

Differential Revision: https://reviews.llvm.org/D137106

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll

index 5ecf46806ceb62519dae307bc90c22261c82f93b..9db4a3fe32fb833eba1ec17276d2396c2063bc11 100644 (file)
@@ -8450,24 +8450,29 @@ struct NodeExtensionHelper {
     return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
   }
 
-  /// Get or create a value that can feed \p Root with the given \p ExtOpc.
-  /// If \p ExtOpc is None, this returns the source of this operand.
+  /// Get or create a value that can feed \p Root with the given extension \p
+  /// SExt. If \p SExt is None, this returns the source of this operand.
   /// \see ::getSource().
   SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG,
-                                Optional<unsigned> ExtOpc) const {
+                                Optional<bool> SExt) const {
+    if (!SExt.has_value())
+      return OrigOperand;
+
+    MVT NarrowVT = getNarrowType(Root);
+
     SDValue Source = getSource();
-    if (!ExtOpc)
+    if (Source.getValueType() == NarrowVT)
       return Source;
 
-    MVT NarrowVT = getNarrowType(Root);
+    unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+
     // If we need an extension, we should be changing the type.
-    assert(Source.getValueType() != NarrowVT && "Needless extension");
     SDLoc DL(Root);
     auto [Mask, VL] = getMaskAndVL(Root);
     switch (OrigOperand.getOpcode()) {
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
-      return DAG.getNode(*ExtOpc, DL, NarrowVT, Source, Mask, VL);
+      return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
     case RISCVISD::VMV_V_X_VL:
       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
                          DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
@@ -8712,13 +8717,10 @@ struct NodeExtensionHelper {
 struct CombineResult {
   /// Opcode to be generated when materializing the combine.
   unsigned TargetOpcode;
-  /// Extension opcode to be applied to the source of LHS when materializing
-  /// TargetOpcode.
-  /// \see NodeExtensionHelper::getSource().
-  Optional<unsigned> LHSExtOpc;
-  /// Extension opcode to be applied to the source of RHS when materializing
-  /// TargetOpcode.
-  Optional<unsigned> RHSExtOpc;
+  // No value means no extension is needed. If extension is needed, the value
+  // indicates if it needs to be sign extended.
+  Optional<bool> SExtLHS;
+  Optional<bool> SExtRHS;
   /// Root of the combine.
   SDNode *Root;
   /// LHS of the TargetOpcode.
@@ -8729,13 +8731,8 @@ struct CombineResult {
   CombineResult(unsigned TargetOpcode, SDNode *Root,
                 const NodeExtensionHelper &LHS, Optional<bool> SExtLHS,
                 const NodeExtensionHelper &RHS, Optional<bool> SExtRHS)
-      : TargetOpcode(TargetOpcode), Root(Root), LHS(LHS), RHS(RHS) {
-    MVT NarrowVT = NodeExtensionHelper::getNarrowType(Root);
-    if (SExtLHS && LHS.getSource().getValueType() != NarrowVT)
-      LHSExtOpc = *SExtLHS ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-    if (SExtRHS && RHS.getSource().getValueType() != NarrowVT)
-      RHSExtOpc = *SExtRHS ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-  }
+      : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
+        Root(Root), LHS(LHS), RHS(RHS) {}
 
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
@@ -8745,8 +8742,8 @@ struct CombineResult {
     std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root);
     Merge = Root->getOperand(2);
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, LHSExtOpc),
-                       RHS.getOrCreateExtendedOp(Root, DAG, RHSExtOpc), Merge,
+                       LHS.getOrCreateExtendedOp(Root, DAG, SExtLHS),
+                       RHS.getOrCreateExtendedOp(Root, DAG, SExtRHS), Merge,
                        Mask, VL);
   }
 };
index 787251565f2820af51e2f7e5c90086d8d8b915bb..976273863be8d20552e967504a6064aee13a0054 100644 (file)
@@ -859,3 +859,19 @@ define <2 x i64> @vwaddu_vx_v2i64_i64(<2 x i32>* %x, i64* %y) nounwind {
   %g = add <2 x i64> %e, %f
   ret <2 x i64> %g
 }
+
+define <4 x i64> @crash(<4 x i16> %x, <4 x i16> %y) {
+; CHECK-LABEL: crash:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vsext.vf4 v10, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v9
+; CHECK-NEXT:    vwaddu.wv v10, v10, v8
+; CHECK-NEXT:    vmv2r.v v8, v10
+; CHECK-NEXT:    ret
+  %a = sext <4 x i16> %x to <4 x i64>
+  %b = zext <4 x i16> %y to <4 x i64>
+  %c = add <4 x i64> %a, %b
+  ret <4 x i64> %c
+}