[AArch64] Allow poison elements of fixed-vectors to be duplicated as a widened element
authorMatt Devereau <matthew.devereau@arm.com>
Mon, 16 Jan 2023 14:21:18 +0000 (14:21 +0000)
committerMatt Devereau <matthew.devereau@arm.com>
Thu, 19 Jan 2023 16:01:30 +0000 (16:01 +0000)
Expanding upon https://reviews.llvm.org/D138203, allow null indices in
InsertElts to be matched with any value and be duplicated if the fixed
vector the scalar values are inserted into is poison, and the scalable vector
the subvector being inserted into is poison.

Differential Revision: https://reviews.llvm.org/D141846

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-dupqlane.ll

index 471b05b..916eefc 100644 (file)
@@ -1436,7 +1436,7 @@ static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
   return std::nullopt;
 }
 
-bool SimplifyValuePattern(SmallVector<Value *> &Vec) {
+bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) {
   size_t VecSize = Vec.size();
   if (VecSize == 1)
     return true;
@@ -1446,13 +1446,20 @@ bool SimplifyValuePattern(SmallVector<Value *> &Vec) {
 
   for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize;
        RHS != Vec.end(); LHS++, RHS++) {
-    if (*LHS != nullptr && *RHS != nullptr && *LHS == *RHS)
-      continue;
-    return false;
+    if (*LHS != nullptr && *RHS != nullptr) {
+      if (*LHS == *RHS)
+        continue;
+      else
+        return false;
+    }
+    if (!AllowPoison)
+      return false;
+    if (*LHS == nullptr && *RHS != nullptr)
+      *LHS = *RHS;
   }
 
   Vec.resize(HalfVecSize);
-  SimplifyValuePattern(Vec);
+  SimplifyValuePattern(Vec, AllowPoison);
   return true;
 }
 
@@ -1476,7 +1483,9 @@ static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC,
     CurrentInsertElt = InsertElt->getOperand(0);
   }
 
-  if (!SimplifyValuePattern(Elts))
+  bool AllowPoison =
+      isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default);
+  if (!SimplifyValuePattern(Elts, AllowPoison))
     return std::nullopt;
 
   // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b)
@@ -1484,9 +1493,13 @@ static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC,
   Builder.SetInsertPoint(&II);
   Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType());
   for (size_t I = 0; I < Elts.size(); I++) {
+    if (Elts[I] == nullptr)
+      continue;
     InsertEltChain = Builder.CreateInsertElement(InsertEltChain, Elts[I],
                                                  Builder.getInt64(I));
   }
+  if (InsertEltChain == nullptr)
+    return std::nullopt;
 
   // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64
   // value or (f16 a, f16 b) as one i32 value. This requires an InsertSubvector
index 9b37583..d059670 100644 (file)
@@ -96,12 +96,11 @@ define dso_local <vscale x 8 x half> @dupq_f16_abcnull_pattern(half %a, half %b,
 ; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <8 x half> poison, half [[A:%.*]], i64 0
 ; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <8 x half> [[TMP2]], half [[C:%.*]], i64 2
-; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <8 x half> [[TMP3]], half [[A]], i64 4
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <8 x half> [[TMP4]], half [[B]], i64 5
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <8 x half> [[TMP5]], half [[C]], i64 6
-; CHECK-NEXT:    [[TMP7:%.*]] = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> [[TMP6]], i64 0)
-; CHECK-NEXT:    [[TMP8:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> [[TMP7]], i64 0)
-; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP8]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> [[TMP3]], i64 0)
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <vscale x 8 x half> [[TMP4]] to <vscale x 2 x i64>
+; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <vscale x 2 x i64> [[TMP5]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <vscale x 2 x i64> [[TMP6]] to <vscale x 8 x half>
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP7]]
 ;
   %1 = insertelement <8 x half> poison, half %a, i64 0
   %2 = insertelement <8 x half> %1, half %b, i64 1
@@ -114,6 +113,57 @@ define dso_local <vscale x 8 x half> @dupq_f16_abcnull_pattern(half %a, half %b,
   ret <vscale x 8 x half> %8
 }
 
+define dso_local <vscale x 8 x half> @dupq_f16_abnull_pattern(half %a, half %b) {
+; CHECK-LABEL: @dupq_f16_abnull_pattern(
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <8 x half> poison, half [[A:%.*]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> [[TMP2]], i64 0)
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <vscale x 8 x half> [[TMP3]] to <vscale x 4 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <vscale x 4 x i32> [[TMP4]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <vscale x 4 x i32> [[TMP5]] to <vscale x 8 x half>
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP6]]
+;
+  %1 = insertelement <8 x half> poison, half %a, i64 0
+  %2 = insertelement <8 x half> %1, half %b, i64 1
+  %3 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> %2, i64 0)
+  %4 = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %3, i64 0)
+  ret <vscale x 8 x half> %4
+}
+
+define dso_local <vscale x 8 x half> @neg_dupq_f16_non_poison_fixed(half %a, half %b, <8 x half> %v) {
+; CHECK-LABEL: @neg_dupq_f16_non_poison_fixed(
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <8 x half> [[V:%.*]], half [[A:%.*]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> [[TMP2]], i64 0)
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> [[TMP3]], i64 0)
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP4]]
+;
+  %1 = insertelement <8 x half> %v, half %a, i64 0
+  %2 = insertelement <8 x half> %1, half %b, i64 1
+  %3 = insertelement <8 x half> %2, half %a, i64 0
+  %4 = insertelement <8 x half> %3, half %b, i64 1
+  %5 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> %4, i64 0)
+  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %5, i64 0)
+  ret <vscale x 8 x half> %6
+}
+
+define dso_local <vscale x 8 x half> @neg_dupq_f16_into_non_poison_scalable(half %a, half %b, <vscale x 8 x half> %v) {
+; CHECK-LABEL: @neg_dupq_f16_into_non_poison_scalable(
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <8 x half> poison, half [[A:%.*]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> [[V:%.*]], <8 x half> [[TMP2]], i64 0)
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> [[TMP3]], i64 0)
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP4]]
+;
+  %1 = insertelement <8 x half> poison, half %a, i64 0
+  %2 = insertelement <8 x half> %1, half %b, i64 1
+  %3 = insertelement <8 x half> %2, half %a, i64 0
+  %4 = insertelement <8 x half> %3, half %b, i64 1
+  %5 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> %v, <8 x half> %4, i64 0)
+  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %5, i64 0)
+  ret <vscale x 8 x half> %6
+}
+
 ; Insert %c to override the last element in the insertelement chain, which will fail to combine
 
 define dso_local <vscale x 8 x half> @neg_dupq_f16_abcd_pattern_double_insert(half %a, half %b, half %c, half %d) {