[SVE][CodeGen] Improve codegen for some zero-extends of masked loads
authorDavid Sherwood <david.sherwood@arm.com>
Thu, 13 Jul 2023 15:21:31 +0000 (15:21 +0000)
committerDavid Sherwood <david.sherwood@arm.com>
Mon, 17 Jul 2023 08:19:27 +0000 (08:19 +0000)
When doing a masked load of an illegal unpacked type and then
zero-extending to some illegal wider types we sometimes end up
with pointless 'and' instructions that are trying to zero bits
that we already know are zero. This patch fixes that by adding
more cases to performSVEAndCombine.

Differential Revision: https://reviews.llvm.org/D155281

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/sve-intrinsics-mask-ldst-ext.ll

index 2a2953c..8059592 100644 (file)
@@ -17053,14 +17053,28 @@ static SDValue performSVEAndCombine(SDNode *N,
 
     uint64_t ExtVal = C->getZExtValue();
 
+    auto MaskAndTypeMatch = [ExtVal](EVT VT) -> bool {
+      return ((ExtVal == 0xFF && VT == MVT::i8) ||
+              (ExtVal == 0xFFFF && VT == MVT::i16) ||
+              (ExtVal == 0xFFFFFFFF && VT == MVT::i32));
+    };
+
     // If the mask is fully covered by the unpack, we don't need to push
     // a new AND onto the operand
     EVT EltTy = UnpkOp->getValueType(0).getVectorElementType();
-    if ((ExtVal == 0xFF && EltTy == MVT::i8) ||
-        (ExtVal == 0xFFFF && EltTy == MVT::i16) ||
-        (ExtVal == 0xFFFFFFFF && EltTy == MVT::i32))
+    if (MaskAndTypeMatch(EltTy))
       return Src;
 
+    // If this is 'and (uunpklo/hi (extload MemTy -> ExtTy)), mask', then check
+    // to see if the mask is all-ones of size MemTy.
+    auto MaskedLoadOp = dyn_cast<MaskedLoadSDNode>(UnpkOp);
+    if (MaskedLoadOp && (MaskedLoadOp->getExtensionType() == ISD::ZEXTLOAD ||
+                         MaskedLoadOp->getExtensionType() == ISD::EXTLOAD)) {
+      EVT EltTy = MaskedLoadOp->getMemoryVT().getVectorElementType();
+      if (MaskAndTypeMatch(EltTy))
+        return Src;
+    }
+
     // Truncate to prevent a DUP with an over wide constant
     APInt Mask = C->getAPIntValue().trunc(EltTy.getSizeInBits());
 
index 55bd383..c495e98 100644 (file)
@@ -52,10 +52,9 @@ define <vscale x 16 x i32> @masked_ld1b_i8_zext_i32(<vscale x 16 x i8> *%base, <
 define <vscale x 8 x i32> @masked_ld1b_nxv8i8_zext_i32(<vscale x 8 x i8> *%a, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: masked_ld1b_nxv8i8_zext_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1b { z0.h }, p0/z, [x0]
-; CHECK-NEXT:    uunpkhi z1.s, z0.h
-; CHECK-NEXT:    and z0.h, z0.h, #0xff
-; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ld1b { z1.h }, p0/z, [x0]
+; CHECK-NEXT:    uunpklo z0.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
 ; CHECK-NEXT:    ret
   %wide.masked.load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8.p0(ptr %a, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> poison)
   %res = zext <vscale x 8 x i8> %wide.masked.load to <vscale x 8 x i32>
@@ -125,10 +124,9 @@ define <vscale x 16 x i64> @masked_ld1b_i8_zext(<vscale x 16 x i8> *%base, <vsca
 define <vscale x 4 x i64> @masked_ld1b_nxv4i8_zext_i64(<vscale x 4 x i8> *%a, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: masked_ld1b_nxv4i8_zext_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1b { z0.s }, p0/z, [x0]
-; CHECK-NEXT:    uunpkhi z1.d, z0.s
-; CHECK-NEXT:    and z0.s, z0.s, #0xff
-; CHECK-NEXT:    uunpklo z0.d, z0.s
+; CHECK-NEXT:    ld1b { z1.s }, p0/z, [x0]
+; CHECK-NEXT:    uunpklo z0.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
 ; CHECK-NEXT:    ret
   %wide.masked.load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8.p0(ptr %a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> poison)
   %res = zext <vscale x 4 x i8> %wide.masked.load to <vscale x 4 x i64>
@@ -186,10 +184,9 @@ define <vscale x 8 x i64> @masked_ld1h_i16_zext(<vscale x 8 x i16> *%base, <vsca
 define <vscale x 4 x i64> @masked_ld1h_nxv4i16_zext(<vscale x 4 x i16> *%a, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: masked_ld1h_nxv4i16_zext:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x0]
-; CHECK-NEXT:    uunpkhi z1.d, z0.s
-; CHECK-NEXT:    and z0.s, z0.s, #0xffff
-; CHECK-NEXT:    uunpklo z0.d, z0.s
+; CHECK-NEXT:    ld1h { z1.s }, p0/z, [x0]
+; CHECK-NEXT:    uunpklo z0.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
 ; CHECK-NEXT:    ret
   %wide.masked.load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16.p0(ptr %a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> poison)
   %res = zext <vscale x 4 x i16> %wide.masked.load to <vscale x 4 x i64>