From 0d153df69e8fe28bdf7e65195d3708f331106088 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin
Date: Thu, 21 Oct 2021 11:30:31 +0100
Subject: [PATCH] [SVE] Fix selection failure when splitting extended masked loads

When splitting a masked load, `GetDependentSplitDestVTs` is used to get the
MemVTs of the high and low parts. If the masked load is extended, this may
return VTs with different element types, which are then used to create the
high and low masked load instructions.
This patch changes `GetDependentSplitDestVTs` to ensure we return VTs with
the same element type.

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D111996
---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp    |  4 ++--
 llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll | 20 ++++++++++++++++++++
 llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll | 20 ++++++++++++++++++++
 3 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 0a5dce3..b928fd3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -10642,14 +10642,14 @@ SelectionDAG::GetDependentSplitDestVTs(const EVT &VT, const EVT &EnvVT,
          "Mixing fixed width and scalable vectors when enveloping a type");
   EVT LoVT, HiVT;
   if (VTNumElts.getKnownMinValue() > EnvNumElts.getKnownMinValue()) {
-    LoVT = EnvVT;
+    LoVT = EVT::getVectorVT(*getContext(), EltTp, EnvNumElts);
     HiVT = EVT::getVectorVT(*getContext(), EltTp, VTNumElts - EnvNumElts);
     *HiIsEmpty = false;
   } else {
     // Flag that hi type has zero storage size, but return split envelop type
     // (this would be easier if vector types with zero elements were allowed).
     LoVT = EVT::getVectorVT(*getContext(), EltTp, VTNumElts);
-    HiVT = EnvVT;
+    HiVT = EVT::getVectorVT(*getContext(), EltTp, EnvNumElts);
     *HiIsEmpty = true;
   }
   return std::make_pair(LoVT, HiVT);
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
index f7efa54..9a20035 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
@@ -70,9 +70,29 @@ define <vscale x 2 x i64> @masked_sload_passthru(<vscale x 2 x i32> *%a, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru) {
   ret <vscale x 2 x i64> %ext
 }
 
+; Return type requires splitting
+define <vscale x 16 x i32> @masked_sload_nxv16i8(<vscale x 16 x i8>* %a, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv16i8:
+; CHECK: punpklo p1.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: punpklo p2.h, p1.b
+; CHECK-NEXT: punpkhi p1.h, p1.b
+; CHECK-NEXT: ld1sb { z0.s }, p2/z, [x0]
+; CHECK-NEXT: punpklo p2.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: ld1sb { z1.s }, p1/z, [x0, #1, mul vl]
+; CHECK-NEXT: ld1sb { z2.s }, p2/z, [x0, #2, mul vl]
+; CHECK-NEXT: ld1sb { z3.s }, p0/z, [x0, #3, mul vl]
+; CHECK-NEXT: ret
+  %load = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>* %a, i32 2, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  %ext = sext <vscale x 16 x i8> %load to <vscale x 16 x i32>
+  ret <vscale x 16 x i32> %ext
+}
+
 declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
 declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
 declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
 declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
+declare <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>*, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
index 7dbebee..79eff4d 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -76,9 +76,29 @@ define <vscale x 2 x i64> @masked_zload_passthru(<vscale x 2 x i32>* %src, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru) {
   ret <vscale x 2 x i64> %ext
 }
 
+; Return type requires splitting
+define <vscale x 8 x i64> @masked_zload_nxv8i16(<vscale x 8 x i16>* %a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv8i16:
+; CHECK: punpklo p1.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: punpklo p2.h, p1.b
+; CHECK-NEXT: punpkhi p1.h, p1.b
+; CHECK-NEXT: ld1h { z0.d }, p2/z, [x0]
+; CHECK-NEXT: punpklo p2.h, p0.b
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: ld1h { z1.d }, p1/z, [x0, #1, mul vl]
+; CHECK-NEXT: ld1h { z2.d }, p2/z, [x0, #2, mul vl]
+; CHECK-NEXT: ld1h { z3.d }, p0/z, [x0, #3, mul vl]
+; CHECK-NEXT: ret
+  %load = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16>* %a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+  %ext = zext <vscale x 8 x i16> %load to <vscale x 8 x i64>
+  ret <vscale x 8 x i64> %ext
+}
+
 declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
 declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
 declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
 declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16>*, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
-- 
2.7.4
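
Note (not part of the applied patch): the splitting rule introduced above can be modelled with a minimal standalone C++ sketch. SimpleVT and splitDestVTs below are hypothetical stand-ins, not LLVM's EVT/SelectionDAG API; they only illustrate that, after this change, both the low and high MemVTs keep the value type's element type and take only an element count from the envelope type, so an nxv16i8 MemVT split against an nxv8i32 envelope yields nxv8i8 + nxv8i8 rather than nxv8i32 + nxv8i8.

#include <cstdio>
#include <string>
#include <utility>

// Illustrative stand-in for a scalable vector type: "vscale x MinElts x EltTy".
struct SimpleVT {
  std::string EltTy;  // element type name, e.g. "i8"
  unsigned MinElts;   // known-minimum element count
};

// Models the corrected GetDependentSplitDestVTs behaviour: both destination
// VTs are built from VT's element type; the envelope contributes only its
// element count, never its element type.
static std::pair<SimpleVT, SimpleVT>
splitDestVTs(const SimpleVT &VT, const SimpleVT &EnvVT, bool &HiIsEmpty) {
  if (VT.MinElts > EnvVT.MinElts) {
    HiIsEmpty = false;
    return {{VT.EltTy, EnvVT.MinElts}, {VT.EltTy, VT.MinElts - EnvVT.MinElts}};
  }
  // The high half has zero storage size; return the split envelope element
  // count, but still with VT's element type.
  HiIsEmpty = true;
  return {{VT.EltTy, VT.MinElts}, {VT.EltTy, EnvVT.MinElts}};
}

int main() {
  // e.g. the MemVT of an nxv16i8 masked load, split against an nxv8i32
  // envelope (the low half of the extended nxv16i32 result).
  SimpleVT VT{"i8", 16}, EnvVT{"i32", 8};
  bool HiIsEmpty = false;
  std::pair<SimpleVT, SimpleVT> Split = splitDestVTs(VT, EnvVT, HiIsEmpty);
  std::printf("Lo = nxv%u%s, Hi = nxv%u%s, HiIsEmpty = %d\n",
              Split.first.MinElts, Split.first.EltTy.c_str(),
              Split.second.MinElts, Split.second.EltTy.c_str(), HiIsEmpty);
  return 0; // Prints: Lo = nxv8i8, Hi = nxv8i8, HiIsEmpty = 0
}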