[SelectionDAG][AArch64] Restrict matchUnaryPredicate to only handle SPLAT_VECTOR...
authorCraig Topper <craig.topper@sifive.com>
Tue, 16 Feb 2021 17:22:44 +0000 (09:22 -0800)
committerCraig Topper <craig.topper@sifive.com>
Tue, 16 Feb 2021 17:22:46 +0000 (09:22 -0800)
fde24661718c7812a20a10e518cd853e8e060107 added support for
scalable vectors to matchUnaryPredicate by handling SPLAT_VECTOR in
addition to BUILD_VECTOR. This was used to enabled UDIV/SDIV/UREM/SREM
by constant expansion in BuildUDIV/BuildSDIV in TargetLowering.cpp

The caller there expects to call getBuildVector from the match factors.
This leads to a crash right now if there is a SPLAT_VECTOR of
fixed vectors since the number of vectors won't match the number
of elements.

To fix this, this patch updates the callers to check the opcode
instead of whether the type is fixed or scalable. This assumes
that only 3 opcodes are handled by matchUnaryPredicate so
I've added an assertion to the final else to check that opcode.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D96174

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/test/CodeGen/AArch64/sve-fixed-length-int-div.ll

index ce1cd3d..500bdb4 100644 (file)
@@ -5032,16 +5032,17 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,
     return SDValue();
 
   SDValue Shift, Factor;
-  if (VT.isFixedLengthVector()) {
+  if (Op1.getOpcode() == ISD::BUILD_VECTOR) {
     Shift = DAG.getBuildVector(ShVT, dl, Shifts);
     Factor = DAG.getBuildVector(VT, dl, Factors);
-  } else if (VT.isScalableVector()) {
+  } else if (Op1.getOpcode() == ISD::SPLAT_VECTOR) {
     assert(Shifts.size() == 1 && Factors.size() == 1 &&
            "Expected matchUnaryPredicate to return one element for scalable "
            "vectors");
     Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]);
     Factor = DAG.getSplatVector(VT, dl, Factors[0]);
   } else {
+    assert(isa<ConstantSDNode>(Op1) && "Expected a constant");
     Shift = Shifts[0];
     Factor = Factors[0];
   }
@@ -5147,12 +5148,12 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   SDValue MagicFactor, Factor, Shift, ShiftMask;
-  if (VT.isFixedLengthVector()) {
+  if (N1.getOpcode() == ISD::BUILD_VECTOR) {
     MagicFactor = DAG.getBuildVector(VT, dl, MagicFactors);
     Factor = DAG.getBuildVector(VT, dl, Factors);
     Shift = DAG.getBuildVector(ShVT, dl, Shifts);
     ShiftMask = DAG.getBuildVector(VT, dl, ShiftMasks);
-  } else if (VT.isScalableVector()) {
+  } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
     assert(MagicFactors.size() == 1 && Factors.size() == 1 &&
            Shifts.size() == 1 && ShiftMasks.size() == 1 &&
            "Expected matchUnaryPredicate to return one element for scalable "
@@ -5162,6 +5163,7 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG,
     Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]);
     ShiftMask = DAG.getSplatVector(VT, dl, ShiftMasks[0]);
   } else {
+    assert(isa<ConstantSDNode>(N1) && "Expected a constant");
     MagicFactor = MagicFactors[0];
     Factor = Factors[0];
     Shift = Shifts[0];
@@ -5303,12 +5305,12 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   SDValue PreShift, PostShift, MagicFactor, NPQFactor;
-  if (VT.isFixedLengthVector()) {
+  if (N1.getOpcode() == ISD::BUILD_VECTOR) {
     PreShift = DAG.getBuildVector(ShVT, dl, PreShifts);
     MagicFactor = DAG.getBuildVector(VT, dl, MagicFactors);
     NPQFactor = DAG.getBuildVector(VT, dl, NPQFactors);
     PostShift = DAG.getBuildVector(ShVT, dl, PostShifts);
-  } else if (VT.isScalableVector()) {
+  } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
     assert(PreShifts.size() == 1 && MagicFactors.size() == 1 &&
            NPQFactors.size() == 1 && PostShifts.size() == 1 &&
            "Expected matchUnaryPredicate to return one for scalable vectors");
@@ -5317,6 +5319,7 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
     NPQFactor = DAG.getSplatVector(VT, dl, NPQFactors[0]);
     PostShift = DAG.getSplatVector(ShVT, dl, PostShifts[0]);
   } else {
+    assert(isa<ConstantSDNode>(N1) && "Expected a constant");
     PreShift = PreShifts[0];
     MagicFactor = MagicFactors[0];
     PostShift = PostShifts[0];
@@ -5806,7 +5809,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
     return SDValue();
 
   SDValue PVal, AVal, KVal, QVal;
-  if (VT.isFixedLengthVector()) {
+  if (D.getOpcode() == ISD::BUILD_VECTOR) {
     if (HadOneDivisor) {
       // Try to turn PAmts into a splat, since we don't care about the values
       // that are currently '0'. If we can't, just keep '0'`s.
@@ -5825,7 +5828,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
     AVal = DAG.getBuildVector(VT, DL, AAmts);
     KVal = DAG.getBuildVector(ShVT, DL, KAmts);
     QVal = DAG.getBuildVector(VT, DL, QAmts);
-  } else if (VT.isScalableVector()) {
+  } else if (D.getOpcode() == ISD::SPLAT_VECTOR) {
     assert(PAmts.size() == 1 && AAmts.size() == 1 && KAmts.size() == 1 &&
            QAmts.size() == 1 &&
            "Expected matchUnaryPredicate to return one element for scalable "
@@ -5835,6 +5838,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
     KVal = DAG.getSplatVector(ShVT, DL, KAmts[0]);
     QVal = DAG.getSplatVector(VT, DL, QAmts[0]);
   } else {
+    assert(isa<ConstantSDNode>(D) && "Expected a constant");
     PVal = PAmts[0];
     AVal = AAmts[0];
     KVal = KAmts[0];
index b6e0655..b702282 100644 (file)
@@ -968,4 +968,20 @@ define void @udiv_v32i64(<32 x i64>* %a, <32 x i64>* %b) #0 {
   ret void
 }
 
+; This used to crash because isUnaryPredicate and BuildUDIV don't know how
+; a SPLAT_VECTOR of fixed vector type should be handled.
+define void @udiv_constantsplat_v8i32(<8 x i32>* %a) #0 {
+; CHECK-LABEL: udiv_constantsplat_v8i32:
+; CHECK: ptrue [[PG:p[0-9]+]].s, vl[[#min(div(VBYTES,4),8)]]
+; CHECK-NEXT: ld1w { [[OP1:z[0-9]+]].s }, [[PG]]/z, [x0]
+; CHECK-NEXT: mov [[OP2:z[0-9]+]].s, #95
+; CHECK-NEXT: udiv [[RES:z[0-9]+]].s, [[PG]]/m, [[OP1]].s, [[OP2]].s
+; CHECK-NEXT: st1w { [[RES]].s }, [[PG]], [x0]
+; CHECK-NEXT: ret
+  %op1 = load <8 x i32>, <8 x i32>* %a
+  %res = udiv <8 x i32> %op1, <i32 95, i32 95, i32 95, i32 95, i32 95, i32 95, i32 95, i32 95>
+  store <8 x i32> %res, <8 x i32>* %a
+  ret void
+}
+
 attributes #0 = { "target-features"="+sve" }