[SVE] Add flag to specify SVE register size, using this to calculate legal vector...
authorPaul Walker <paul.walker@arm.com>
Fri, 13 Dec 2019 15:03:17 +0000 (15:03 +0000)
committerPaul Walker <paul.walker@arm.com>
Thu, 18 Jun 2020 12:11:16 +0000 (12:11 +0000)
Adds aarch64-sve-vector-bits-{min,max} to allow the size of SVE
data registers (in bits) to be specified. This allows the code
generator to make assumptions it normally couldn't. As a starting
point this information is used to mark fixed length vector types
that can fit within the specified size as legal.

Reviewers: rengolin, efriedma

Subscribers: tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80384

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64Subtarget.cpp
llvm/lib/Target/AArch64/AArch64Subtarget.h
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll [new file with mode: 0644]

index 6177a80..8c82d88 100644 (file)
@@ -184,6 +184,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
     addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
 
+    if (useSVEForFixedLengthVectors()) {
+      for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
+        if (useSVEForFixedLengthVectorVT(VT))
+          addRegisterClass(VT, &AArch64::ZPRRegClass);
+
+      for (MVT VT : MVT::fp_fixedlen_vector_valuetypes())
+        if (useSVEForFixedLengthVectorVT(VT))
+          addRegisterClass(VT, &AArch64::ZPRRegClass);
+    }
+
     for (auto VT : { MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64 }) {
       setOperationAction(ISD::SADDSAT, VT, Legal);
       setOperationAction(ISD::UADDSAT, VT, Legal);
@@ -3474,6 +3484,51 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   }
 }
 
+bool AArch64TargetLowering::useSVEForFixedLengthVectors() const {
+  // Prefer NEON unless larger SVE registers are available.
+  return Subtarget->hasSVE() && Subtarget->getMinSVEVectorSizeInBits() >= 256;
+}
+
+bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(MVT VT) const {
+  assert(VT.isFixedLengthVector());
+  if (!useSVEForFixedLengthVectors())
+    return false;
+
+  // Fixed length predicates should be promoted to i8.
+  // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work.
+  if (VT.getVectorElementType() == MVT::i1)
+    return false;
+
+  // Don't use SVE for vectors we cannot scalarize if required.
+  switch (VT.getVectorElementType().SimpleTy) {
+  default:
+    return false;
+  case MVT::i8:
+  case MVT::i16:
+  case MVT::i32:
+  case MVT::i64:
+  case MVT::f16:
+  case MVT::f32:
+  case MVT::f64:
+    break;
+  }
+
+  // Ensure NEON MVTs only belong to a single register class.
+  if (VT.getSizeInBits() <= 128)
+    return false;
+
+  // Don't use SVE for types that don't fit.
+  if (VT.getSizeInBits() > Subtarget->getMinSVEVectorSizeInBits())
+    return false;
+
+  // TODO: Perhaps an artificial restriction, but worth having whilst getting
+  // the base fixed length SVE support in place.
+  if (!VT.isPow2VectorType())
+    return false;
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 //                      Calling Convention Implementation
 //===----------------------------------------------------------------------===//
index 6d47fc8..0c0be2a 100644 (file)
@@ -907,6 +907,9 @@ private:
 
   bool shouldLocalize(const MachineInstr &MI,
                       const TargetTransformInfo *TTI) const override;
+
+  bool useSVEForFixedLengthVectors() const;
+  bool useSVEForFixedLengthVectorVT(MVT VT) const;
 };
 
 namespace AArch64 {
index 2eed344..9bb533f 100644 (file)
@@ -47,6 +47,18 @@ static cl::opt<bool>
                    cl::desc("Call nonlazybind functions via direct GOT load"),
                    cl::init(false), cl::Hidden);
 
+static cl::opt<unsigned> SVEVectorBitsMax(
+    "aarch64-sve-vector-bits-max",
+    cl::desc("Assume SVE vector registers are at most this big, "
+             "with zero meaning no maximum size is assumed."),
+    cl::init(0), cl::Hidden);
+
+static cl::opt<unsigned> SVEVectorBitsMin(
+    "aarch64-sve-vector-bits-min",
+    cl::desc("Assume SVE vector registers are at least this big, "
+             "with zero meaning no minimum size is assumed."),
+    cl::init(0), cl::Hidden);
+
 AArch64Subtarget &
 AArch64Subtarget::initializeSubtargetDependencies(StringRef FS,
                                                   StringRef CPUString) {
@@ -329,3 +341,25 @@ void AArch64Subtarget::mirFileLoaded(MachineFunction &MF) const {
   if (!MFI.isMaxCallFrameSizeComputed())
     MFI.computeMaxCallFrameSize(MF);
 }
+
+unsigned AArch64Subtarget::getMaxSVEVectorSizeInBits() const {
+  assert(HasSVE && "Tried to get SVE vector length without SVE support!");
+  assert(SVEVectorBitsMax % 128 == 0 &&
+         "SVE requires vector length in multiples of 128!");
+  assert((SVEVectorBitsMax >= SVEVectorBitsMin || SVEVectorBitsMax == 0) &&
+         "Minimum SVE vector size should not be larger than its maximum!");
+  if (SVEVectorBitsMax == 0)
+    return 0;
+  return (std::max(SVEVectorBitsMin, SVEVectorBitsMax) / 128) * 128;
+}
+
+unsigned AArch64Subtarget::getMinSVEVectorSizeInBits() const {
+  assert(HasSVE && "Tried to get SVE vector length without SVE support!");
+  assert(SVEVectorBitsMin % 128 == 0 &&
+         "SVE requires vector length in multiples of 128!");
+  assert((SVEVectorBitsMax >= SVEVectorBitsMin || SVEVectorBitsMax == 0) &&
+         "Minimum SVE vector size should not be larger than its maximum!");
+  if (SVEVectorBitsMax == 0)
+    return (SVEVectorBitsMin / 128) * 128;
+  return (std::min(SVEVectorBitsMin, SVEVectorBitsMax) / 128) * 128;
+}
index ba0660f..221c103 100644 (file)
@@ -534,6 +534,12 @@ public:
   }
 
   void mirFileLoaded(MachineFunction &MF) const override;
+
+  // Return the known range for the bit length of SVE data registers. A value
+  // of 0 means nothing is known about that particular limit beyong what's
+  // implied by the architecture.
+  unsigned getMaxSVEVectorSizeInBits() const;
+  unsigned getMinSVEVectorSizeInBits() const;
 };
 } // End llvm namespace
 
index f7233d3..8e85f92 100644 (file)
@@ -98,6 +98,8 @@ public:
 
   unsigned getRegisterBitWidth(bool Vector) const {
     if (Vector) {
+      if (ST->hasSVE())
+        return std::max(ST->getMinSVEVectorSizeInBits(), 128u);
       if (ST->hasNEON())
         return 128;
       return 0;
diff --git a/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll b/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll
new file mode 100644 (file)
index 0000000..ba7db1c
--- /dev/null
@@ -0,0 +1,60 @@
+; RUN: opt < %s -cost-model -analyze | FileCheck %s -D#VBITS=128
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=128 | FileCheck %s -D#VBITS=128
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=256 | FileCheck %s -D#VBITS=256
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=384 | FileCheck %s -D#VBITS=256
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=512 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=640 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=768 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=896 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1024 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1152 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1280 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1408 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1536 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1664 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1792 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1920 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=2048 | FileCheck %s -D#VBITS=2048
+
+; VBITS represents the useful bit size of a vector register from the code
+; generator's point of view. It is clamped to power-of-2 values because
+; only power-of-2 vector lengths are considered legal, regardless of the
+; user specified vector length.
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; Ensure the cost of legalisation is removed as the vector length grows.
+; NOTE: Assumes BaseCost_add=1, BaseCost_fadd=2.
+define void @add() #0 {
+; CHECK-LABEL: Printing analysis 'Cost Model Analysis' for function 'add':
+; CHECK: cost of [[#div(127,VBITS)+1]] for instruction:   %add128 = add <4 x i32> undef, undef
+; CHECK: cost of [[#div(255,VBITS)+1]] for instruction:   %add256 = add <8 x i32> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction:   %add512 = add <16 x i32> undef, undef
+; CHECK: cost of [[#div(1023,VBITS)+1]] for instruction:   %add1024 = add <32 x i32> undef, undef
+; CHECK: cost of [[#div(2047,VBITS)+1]] for instruction:   %add2048 = add <64 x i32> undef, undef
+  %add128 = add <4 x i32> undef, undef
+  %add256 = add <8 x i32> undef, undef
+  %add512 = add <16 x i32> undef, undef
+  %add1024 = add <32 x i32> undef, undef
+  %add2048 = add <64 x i32> undef, undef
+
+; Using a single vector length, ensure all element types are recognised.
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction:   %add512.i8 = add <64 x i8> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction:   %add512.i16 = add <32 x i16> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction:   %add512.i32 = add <16 x i32> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction:   %add512.i64 = add <8 x i64> undef, undef
+; CHECK: cost of [[#mul(div(511,VBITS)+1,2)]] for instruction:   %add512.f16 = fadd <32 x half> undef, undef
+; CHECK: cost of [[#mul(div(511,VBITS)+1,2)]] for instruction:   %add512.f32 = fadd <16 x float> undef, undef
+; CHECK: cost of [[#mul(div(511,VBITS)+1,2)]] for instruction:   %add512.f64 = fadd <8 x double> undef, undef
+  %add512.i8 = add <64 x i8> undef, undef
+  %add512.i16 = add <32 x i16> undef, undef
+  %add512.i32 = add <16 x i32> undef, undef
+  %add512.i64 = add <8 x i64> undef, undef
+  %add512.f16 = fadd <32 x half> undef, undef
+  %add512.f32 = fadd <16 x float> undef, undef
+  %add512.f64 = fadd <8 x double> undef, undef
+
+  ret void
+}
+
+attributes #0 = { "target-features"="+sve" }