[AArch64][SVE] Ensure PTEST operands have type nxv16i1
authorRosie Sumpter <rosie.sumpter@arm.com>
Thu, 30 Jun 2022 11:15:00 +0000 (12:15 +0100)
committerRosie Sumpter <rosie.sumpter@arm.com>
Tue, 12 Jul 2022 08:27:59 +0000 (09:27 +0100)
Currently any legal predicate types will be pattern-matched when
creating a PTEST instruction. This could be a problem in future since
PTEST always uses the .B specifier for the operand, but it is not
always guaranteed that the extra lanes of unpacked types (e.g. nxv4i1)
are zero. This patch ensures the operands of PTEST are type nxv16i1,
where the undef lanes are set to zero.

Differential Revision: https://reviews.llvm.org/D129282/

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
llvm/lib/Target/AArch64/SVEInstrFormats.td
llvm/test/CodeGen/AArch64/sve-setcc.ll

index 8108b77..447ad10 100644 (file)
@@ -237,6 +237,39 @@ static bool isMergePassthruOpcode(unsigned Opc) {
   }
 }
 
+// Returns true if inactive lanes are known to be zeroed by construction.
+static bool isZeroingInactiveLanes(SDValue Op) {
+  switch (Op.getOpcode()) {
+  default:
+    // We guarantee i1 splat_vectors to zero the other lanes by
+    // implementing it with ptrue and possibly a punpklo for nxv1i1.
+    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
+      return true;
+    return false;
+  case AArch64ISD::PTRUE:
+  case AArch64ISD::SETCC_MERGE_ZERO:
+    return true;
+  case ISD::INTRINSIC_WO_CHAIN:
+    switch (Op.getConstantOperandVal(0)) {
+    default:
+      return false;
+    case Intrinsic::aarch64_sve_ptrue:
+    case Intrinsic::aarch64_sve_pnext:
+    case Intrinsic::aarch64_sve_cmpeq_wide:
+    case Intrinsic::aarch64_sve_cmpne_wide:
+    case Intrinsic::aarch64_sve_cmpge_wide:
+    case Intrinsic::aarch64_sve_cmpgt_wide:
+    case Intrinsic::aarch64_sve_cmplt_wide:
+    case Intrinsic::aarch64_sve_cmple_wide:
+    case Intrinsic::aarch64_sve_cmphs_wide:
+    case Intrinsic::aarch64_sve_cmphi_wide:
+    case Intrinsic::aarch64_sve_cmplo_wide:
+    case Intrinsic::aarch64_sve_cmpls_wide:
+      return true;
+    }
+  }
+}
+
 AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                                              const AArch64Subtarget &STI)
     : TargetLowering(TM), Subtarget(&STI) {
@@ -4368,16 +4401,18 @@ static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
                      DAG.getTargetConstant(Pattern, DL, MVT::i32));
 }
 
-SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
-                                                      SelectionDAG &DAG) const {
+// Returns a safe bitcast between two scalable vector predicates, where
+// any newly created lanes from a widening bitcast are defined as zero.
+static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
   SDLoc DL(Op);
   EVT InVT = Op.getValueType();
 
   assert(InVT.getVectorElementType() == MVT::i1 &&
          VT.getVectorElementType() == MVT::i1 &&
          "Expected a predicate-to-predicate bitcast");
-  assert(VT.isScalableVector() && isTypeLegal(VT) &&
-         InVT.isScalableVector() && isTypeLegal(InVT) &&
+  assert(VT.isScalableVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+         InVT.isScalableVector() &&
+         DAG.getTargetLoweringInfo().isTypeLegal(InVT) &&
          "Only expect to cast between legal scalable predicate types!");
 
   // Return the operand if the cast isn't changing type,
@@ -4396,33 +4431,8 @@ SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
 
   // Check if the other lanes are already known to be zeroed by
   // construction.
-  switch (Op.getOpcode()) {
-  default:
-    // We guarantee i1 splat_vectors to zero the other lanes by
-    // implementing it with ptrue and possibly a punpklo for nxv1i1.
-    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
-      return Reinterpret;
-    break;
-  case AArch64ISD::SETCC_MERGE_ZERO:
+  if (isZeroingInactiveLanes(Op))
     return Reinterpret;
-  case ISD::INTRINSIC_WO_CHAIN:
-    switch (Op.getConstantOperandVal(0)) {
-    default:
-      break;
-    case Intrinsic::aarch64_sve_ptrue:
-    case Intrinsic::aarch64_sve_cmpeq_wide:
-    case Intrinsic::aarch64_sve_cmpne_wide:
-    case Intrinsic::aarch64_sve_cmpge_wide:
-    case Intrinsic::aarch64_sve_cmpgt_wide:
-    case Intrinsic::aarch64_sve_cmplt_wide:
-    case Intrinsic::aarch64_sve_cmple_wide:
-    case Intrinsic::aarch64_sve_cmphs_wide:
-    case Intrinsic::aarch64_sve_cmphi_wide:
-    case Intrinsic::aarch64_sve_cmplo_wide:
-    case Intrinsic::aarch64_sve_cmpls_wide:
-      return Reinterpret;
-    }
-  }
 
   // Zero the newly introduced lanes.
   SDValue Mask = DAG.getConstant(1, DL, InVT);
@@ -16164,12 +16174,24 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
   assert(Op.getValueType().isScalableVector() &&
          TLI.isTypeLegal(Op.getValueType()) &&
          "Expected legal scalable vector type!");
+  assert(Op.getValueType() == Pg.getValueType() &&
+         "Expected same type for PTEST operands");
 
   // Ensure target specific opcodes are using legal type.
   EVT OutVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
   SDValue TVal = DAG.getConstant(1, DL, OutVT);
   SDValue FVal = DAG.getConstant(0, DL, OutVT);
 
+  // Ensure operands have type nxv16i1.
+  if (Op.getValueType() != MVT::nxv16i1) {
+    if ((Cond == AArch64CC::ANY_ACTIVE || Cond == AArch64CC::NONE_ACTIVE) &&
+        isZeroingInactiveLanes(Op))
+      Pg = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pg);
+    else
+      Pg = getSVEPredicateBitCast(MVT::nxv16i1, Pg, DAG);
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Op);
+  }
+
   // Set condition code (CC) flags.
   SDValue Test = DAG.getNode(AArch64ISD::PTEST, DL, MVT::Other, Pg, Op);
 
index 48a559b..e02b5e5 100644 (file)
@@ -1154,10 +1154,6 @@ private:
   // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
-  // Returns a safe bitcast between two scalable vector predicates, where
-  // any newly created lanes from a widening bitcast are defined as zero.
-  SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
-
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
 };
index 58ef4b3..c66f9cf 100644 (file)
@@ -778,7 +778,7 @@ let Predicates = [HasSVEorSME] in {
   defm BRKB_PPmP  : sve_int_break_m<0b101, "brkb",  int_aarch64_sve_brkb>;
   defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>;
 
-  def PTEST_PP : sve_int_ptest<0b010000, "ptest">;
+  def PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest>;
   defm PFALSE  : sve_int_pfalse<0b000000, "pfalse">;
   defm PFIRST  : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>;
   defm PNEXT   : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>;
@@ -2131,17 +2131,6 @@ let Predicates = [HasSVEorSME] in {
     def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
 
-  def : Pat<(AArch64ptest (nxv16i1 PPR:$pg), (nxv16i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv8i1 PPR:$pg), (nxv8i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv4i1 PPR:$pg), (nxv4i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv2i1 PPR:$pg), (nxv2i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv1i1 PPR:$pg), (nxv1i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-
   let AddedComplexity = 1 in {
   class LD1RPat<ValueType vt, SDPatternOperator operator,
                 Instruction load, Instruction ptrue, ValueType index_vt, ComplexPattern CP, Operand immtype> :
index 80e38d0..7cdd4c4 100644 (file)
@@ -650,11 +650,11 @@ multiclass sve_int_pfalse<bits<6> opc, string asm> {
   def : Pat<(nxv1i1 immAllZerosV), (!cast<Instruction>(NAME))>;
 }
 
-class sve_int_ptest<bits<6> opc, string asm>
+class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
 : I<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
   asm, "\t$Pg, $Pn",
   "",
-  []>, Sched<[]> {
+  [(op (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>, Sched<[]> {
   bits<4> Pg;
   bits<4> Pn;
   let Inst{31-24} = 0b00100101;
index 8d7aae8..60ee9b3 100644 (file)
@@ -51,7 +51,10 @@ if.end:
 define void @sve_cmplt_setcc_hslo(<vscale x 8 x i16>* %out, <vscale x 8 x i16> %in, <vscale x 8 x i1> %pg) {
 ; CHECK-LABEL: sve_cmplt_setcc_hslo:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    cmplt p1.h, p0/z, z0.h, #0
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    cmplt p2.h, p0/z, z0.h, #0
+; CHECK-NEXT:    and p1.b, p0/z, p0.b, p1.b
+; CHECK-NEXT:    ptest p1, p2.b
 ; CHECK-NEXT:    b.hs .LBB2_2
 ; CHECK-NEXT:  // %bb.1: // %if.then
 ; CHECK-NEXT:    st1h { z0.h }, p0, [x0]