From b62e6f19d71359f2c901c834764191355ad06420 Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Thu, 16 Sep 2021 16:03:52 +0100 Subject: [PATCH] [SelectionDAG] Handle promotion + widening in getCopyToPartsVector Some vectors require both widening and promotion for their legalization. This case is not yet handled in getCopyToPartsVector and falls back on scalarizing by default. BBecause scalable vectors can't easily be scalarised, we need to implement this in two separate stages: 1. Widen the vector. 2. Promote the vector. As part of this patch, PromoteIntRes_CONCAT_VECTORS also needed to be made scalable aware. Instead of falling back on scalarizing the vector (fixed-width only), each sub-part of the CONCAT vector is promoted, and the operation is performed on the type with the widest element type, finally truncating the result to the promoted result type. Differential Revision: https://reviews.llvm.org/D110646 --- .../CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp | 40 ++++++++++++++++++++-- .../CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 11 ++++++ .../CodeGen/AArch64/sve-extract-scalable-vector.ll | 33 ++++++++++++++++++ 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp index 076dbf7..113f624 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -23,6 +23,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" +#include using namespace llvm; #define DEBUG_TYPE "legalize-types" @@ -5057,11 +5058,46 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CONCAT_VECTORS(SDNode *N) { EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT); assert(NOutVT.isVector() && "This type must be promoted to a vector type"); + unsigned NumOperands = N->getNumOperands(); + unsigned NumOutElem = NOutVT.getVectorMinNumElements(); EVT OutElemTy = NOutVT.getVectorElementType(); + if (OutVT.isScalableVector()) { + // Find the largest promoted element type for each of the operands. + SDUse *MaxSizedValue = std::max_element( + N->op_begin(), N->op_end(), [](const SDValue &A, const SDValue &B) { + EVT AVT = A.getValueType().getVectorElementType(); + EVT BVT = B.getValueType().getVectorElementType(); + return AVT.getScalarSizeInBits() < BVT.getScalarSizeInBits(); + }); + EVT MaxElementVT = MaxSizedValue->getValueType().getVectorElementType(); + + // Then promote all vectors to the largest element type. + SmallVector Ops; + for (unsigned I = 0; I < NumOperands; ++I) { + SDValue Op = N->getOperand(I); + EVT OpVT = Op.getValueType(); + if (getTypeAction(OpVT) == TargetLowering::TypePromoteInteger) + Op = GetPromotedInteger(Op); + else + assert(getTypeAction(OpVT) == TargetLowering::TypeLegal && + "Unhandled legalization type"); + + if (OpVT.getVectorElementType().getScalarSizeInBits() < + MaxElementVT.getScalarSizeInBits()) + Op = DAG.getAnyExtOrTrunc(Op, dl, + OpVT.changeVectorElementType(MaxElementVT)); + Ops.push_back(Op); + } + + // Do the CONCAT on the promoted type and finally truncate to (the promoted) + // NOutVT. + return DAG.getAnyExtOrTrunc( + DAG.getNode(ISD::CONCAT_VECTORS, dl, + OutVT.changeVectorElementType(MaxElementVT), Ops), + dl, NOutVT); + } unsigned NumElem = N->getOperand(0).getValueType().getVectorNumElements(); - unsigned NumOutElem = NOutVT.getVectorNumElements(); - unsigned NumOperands = N->getNumOperands(); assert(NumElem * NumOperands == NumOutElem && "Unexpected number of elements"); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 83c3c16..77e334a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -673,6 +673,17 @@ static void getCopyToPartsVector(SelectionDAG &DAG, const SDLoc &DL, // Promoted vector extract Val = DAG.getAnyExtOrTrunc(Val, DL, PartVT); + } else if (PartEVT.isVector() && + PartEVT.getVectorElementType() != + ValueVT.getVectorElementType() && + TLI.getTypeAction(*DAG.getContext(), ValueVT) == + TargetLowering::TypeWidenVector) { + // Combination of widening and promotion. + EVT WidenVT = + EVT::getVectorVT(*DAG.getContext(), ValueVT.getVectorElementType(), + PartVT.getVectorElementCount()); + SDValue Widened = widenVectorToPartType(DAG, Val, DL, WidenVT); + Val = DAG.getAnyExtOrTrunc(Widened, DL, PartVT); } else { if (ValueVT.getVectorElementCount().isScalar()) { Val = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, PartVT, Val, diff --git a/llvm/test/CodeGen/AArch64/sve-extract-scalable-vector.ll b/llvm/test/CodeGen/AArch64/sve-extract-scalable-vector.ll index 71a1964..0dc12bf 100644 --- a/llvm/test/CodeGen/AArch64/sve-extract-scalable-vector.ll +++ b/llvm/test/CodeGen/AArch64/sve-extract-scalable-vector.ll @@ -440,6 +440,39 @@ define @extract_nxv4i8_nxv12i8_8( %in) { declare @llvm.experimental.vector.extract.nxv4i8.nxv12i8(, i64) ; +; Extract i8 vector that needs both widening + promotion from one that needs widening. +; (nxv6i8 -> nxv8i8 -> nxv8i16) +; +define @extract_nxv6i8_nxv12i8_0( %in) { +; CHECK-LABEL: extract_nxv6i8_nxv12i8_0: +; CHECK: // %bb.0: +; CHECK-NEXT: uunpklo z0.h, z0.b +; CHECK-NEXT: ret + %res = call @llvm.experimental.vector.extract.nxv6i8.nxv12i8( %in, i64 0) + ret %res +} + +define @extract_nxv6i8_nxv12i8_6( %in) { +; CHECK-LABEL: extract_nxv6i8_nxv12i8_6: +; CHECK: // %bb.0: +; CHECK-NEXT: uunpkhi z1.h, z0.b +; CHECK-NEXT: uunpklo z0.h, z0.b +; CHECK-NEXT: uunpklo z1.s, z1.h +; CHECK-NEXT: uunpkhi z0.s, z0.h +; CHECK-NEXT: uunpkhi z2.d, z1.s +; CHECK-NEXT: uunpklo z1.d, z1.s +; CHECK-NEXT: uunpkhi z0.d, z0.s +; CHECK-NEXT: uzp1 z2.s, z2.s, z0.s +; CHECK-NEXT: uzp1 z0.s, z0.s, z1.s +; CHECK-NEXT: uzp1 z0.h, z0.h, z2.h +; CHECK-NEXT: ret + %res = call @llvm.experimental.vector.extract.nxv6i8.nxv12i8( %in, i64 6) + ret %res +} + +declare @llvm.experimental.vector.extract.nxv6i8.nxv12i8(, i64) + +; ; Extract half i8 vector that needs promotion from one that needs splitting. ; define @extract_nxv8i8_nxv32i8_0( %in) { -- 2.7.4