[SVE] Fix selection failure when splitting extended masked loads
authorKerry McLaughlin <kerry.mclaughlin@arm.com>
Thu, 21 Oct 2021 10:30:31 +0000 (11:30 +0100)
committerKerry McLaughlin <kerry.mclaughlin@arm.com>
Thu, 21 Oct 2021 12:04:38 +0000 (13:04 +0100)
When splitting a masked load, `GetDependentSplitDestVTs` is used to get the
MemVTs of the high and low parts. If the masked load is extended, this
may return VTs with different element types which are used to create the
high & low masked load instructions.
This patch changes `GetDependentSplitDestVTs` to ensure we return VTs with
the same element type.

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D111996

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll

index 0a5dce3..b928fd3 100644 (file)
@@ -10642,14 +10642,14 @@ SelectionDAG::GetDependentSplitDestVTs(const EVT &VT, const EVT &EnvVT,
          "Mixing fixed width and scalable vectors when enveloping a type");
   EVT LoVT, HiVT;
   if (VTNumElts.getKnownMinValue() > EnvNumElts.getKnownMinValue()) {
-    LoVT = EnvVT;
+    LoVT = EVT::getVectorVT(*getContext(), EltTp, EnvNumElts);
     HiVT = EVT::getVectorVT(*getContext(), EltTp, VTNumElts - EnvNumElts);
     *HiIsEmpty = false;
   } else {
     // Flag that hi type has zero storage size, but return split envelop type
     // (this would be easier if vector types with zero elements were allowed).
     LoVT = EVT::getVectorVT(*getContext(), EltTp, VTNumElts);
-    HiVT = EnvVT;
+    HiVT = EVT::getVectorVT(*getContext(), EltTp, EnvNumElts);
     *HiIsEmpty = true;
   }
   return std::make_pair(LoVT, HiVT);
index f7efa54..9a20035 100644 (file)
@@ -70,9 +70,29 @@ define <vscale x 2 x i64> @masked_sload_passthru(<vscale x 2 x i32> *%a, <vscale
   ret <vscale x 2 x i64> %ext
 }
 
+; Return type requires splitting
+define <vscale x 16 x i32> @masked_sload_nxv16i8(<vscale x 16 x i8>* %a, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv16i8:
+; CHECK:         punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    punpklo p2.h, p1.b
+; CHECK-NEXT:    punpkhi p1.h, p1.b
+; CHECK-NEXT:    ld1sb { z0.s }, p2/z, [x0]
+; CHECK-NEXT:    punpklo p2.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    ld1sb { z1.s }, p1/z, [x0, #1, mul vl]
+; CHECK-NEXT:    ld1sb { z2.s }, p2/z, [x0, #2, mul vl]
+; CHECK-NEXT:    ld1sb { z3.s }, p0/z, [x0, #3, mul vl]
+; CHECK-NEXT:    ret
+  %load = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>* %a, i32 2, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  %ext = sext <vscale x 16 x i8> %load to <vscale x 16 x i32>
+  ret <vscale x 16 x i32> %ext
+}
+
 declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
 declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
 declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
 declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
+declare <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>*, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
index 7dbebee..79eff4d 100644 (file)
@@ -76,9 +76,29 @@ define <vscale x 2 x i64> @masked_zload_passthru(<vscale x 2 x i32>* %src, <vsca
   ret <vscale x 2 x i64> %ext
 }
 
+; Return type requires splitting
+define <vscale x 8 x i64> @masked_zload_nxv8i16(<vscale x 8 x i16>* %a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv8i16:
+; CHECK:       punpklo p1.h, p0.b
+; CHECK-NEXT:  punpkhi p0.h, p0.b
+; CHECK-NEXT:  punpklo p2.h, p1.b
+; CHECK-NEXT:  punpkhi p1.h, p1.b
+; CHECK-NEXT:  ld1h { z0.d }, p2/z, [x0]
+; CHECK-NEXT:  punpklo p2.h, p0.b
+; CHECK-NEXT:  punpkhi p0.h, p0.b
+; CHECK-NEXT:  ld1h { z1.d }, p1/z, [x0, #1, mul vl]
+; CHECK-NEXT:  ld1h { z2.d }, p2/z, [x0, #2, mul vl]
+; CHECK-NEXT:  ld1h { z3.d }, p0/z, [x0, #3, mul vl]
+; CHECK-NEXT:  ret
+  %load = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16>* %a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+  %ext = zext <vscale x 8 x i16> %load to <vscale x 8 x i64>
+  ret <vscale x 8 x i64> %ext
+}
+
 declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
 declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
 declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
 declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16>*, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)