[SVE][NFC] Use ScalableVectorType in CGBuiltin
authorChristopher Tetreault <ctetreau@quicinc.com>
Mon, 27 Apr 2020 23:11:40 +0000 (16:11 -0700)
committerChristopher Tetreault <ctetreau@quicinc.com>
Mon, 27 Apr 2020 23:29:45 +0000 (16:29 -0700)
Summary: * Upgrade some usages of VectorType to use ScalableVectorType

Reviewers: efriedma, david-arm, fpetrogalli, kmclaughlin

Reviewed By: efriedma

Subscribers: tschuett, rkruppe, psnobl, cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D78842

clang/lib/CodeGen/CGBuiltin.cpp
clang/lib/CodeGen/CodeGenFunction.h

index b0e5eeb..ee99e35 100644 (file)
@@ -7542,66 +7542,68 @@ llvm::Type *CodeGenFunction::getEltType(SVETypeFlags TypeFlags) {
 
 // Return the llvm predicate vector type corresponding to the specified element
 // TypeFlags.
-llvm::VectorType* CodeGenFunction::getSVEPredType(SVETypeFlags TypeFlags) {
+llvm::ScalableVectorType *
+CodeGenFunction::getSVEPredType(SVETypeFlags TypeFlags) {
   switch (TypeFlags.getEltType()) {
   default: llvm_unreachable("Unhandled SVETypeFlag!");
 
   case SVETypeFlags::EltTyInt8:
-    return llvm::VectorType::get(Builder.getInt1Ty(), { 16, true });
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 16);
   case SVETypeFlags::EltTyInt16:
-    return llvm::VectorType::get(Builder.getInt1Ty(), { 8, true });
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 8);
   case SVETypeFlags::EltTyInt32:
-    return llvm::VectorType::get(Builder.getInt1Ty(), { 4, true });
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 4);
   case SVETypeFlags::EltTyInt64:
-    return llvm::VectorType::get(Builder.getInt1Ty(), { 2, true });
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 2);
 
   case SVETypeFlags::EltTyFloat16:
-    return llvm::VectorType::get(Builder.getInt1Ty(), { 8, true });
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 8);
   case SVETypeFlags::EltTyFloat32:
-    return llvm::VectorType::get(Builder.getInt1Ty(), { 4, true });
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 4);
   case SVETypeFlags::EltTyFloat64:
-    return llvm::VectorType::get(Builder.getInt1Ty(), { 2, true });
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 2);
   }
 }
 
 // Return the llvm vector type corresponding to the specified element TypeFlags.
-llvm::VectorType *CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) {
+llvm::ScalableVectorType *
+CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) {
   switch (TypeFlags.getEltType()) {
   default:
     llvm_unreachable("Invalid SVETypeFlag!");
 
   case SVETypeFlags::EltTyInt8:
-    return llvm::VectorType::get(Builder.getInt8Ty(), {16, true});
+    return llvm::ScalableVectorType::get(Builder.getInt8Ty(), 16);
   case SVETypeFlags::EltTyInt16:
-    return llvm::VectorType::get(Builder.getInt16Ty(), {8, true});
+    return llvm::ScalableVectorType::get(Builder.getInt16Ty(), 8);
   case SVETypeFlags::EltTyInt32:
-    return llvm::VectorType::get(Builder.getInt32Ty(), {4, true});
+    return llvm::ScalableVectorType::get(Builder.getInt32Ty(), 4);
   case SVETypeFlags::EltTyInt64:
-    return llvm::VectorType::get(Builder.getInt64Ty(), {2, true});
+    return llvm::ScalableVectorType::get(Builder.getInt64Ty(), 2);
 
   case SVETypeFlags::EltTyFloat16:
-    return llvm::VectorType::get(Builder.getHalfTy(), {8, true});
+    return llvm::ScalableVectorType::get(Builder.getHalfTy(), 8);
   case SVETypeFlags::EltTyFloat32:
-    return llvm::VectorType::get(Builder.getFloatTy(), {4, true});
+    return llvm::ScalableVectorType::get(Builder.getFloatTy(), 4);
   case SVETypeFlags::EltTyFloat64:
-    return llvm::VectorType::get(Builder.getDoubleTy(), {2, true});
+    return llvm::ScalableVectorType::get(Builder.getDoubleTy(), 2);
 
   case SVETypeFlags::EltTyBool8:
-    return llvm::VectorType::get(Builder.getInt1Ty(), {16, true});
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 16);
   case SVETypeFlags::EltTyBool16:
-    return llvm::VectorType::get(Builder.getInt1Ty(), {8, true});
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 8);
   case SVETypeFlags::EltTyBool32:
-    return llvm::VectorType::get(Builder.getInt1Ty(), {4, true});
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 4);
   case SVETypeFlags::EltTyBool64:
-    return llvm::VectorType::get(Builder.getInt1Ty(), {2, true});
+    return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 2);
   }
 }
 
 constexpr unsigned SVEBitsPerBlock = 128;
 
-static llvm::VectorType* getSVEVectorForElementType(llvm::Type *EltTy) {
+static llvm::ScalableVectorType *getSVEVectorForElementType(llvm::Type *EltTy) {
   unsigned NumElts = SVEBitsPerBlock / EltTy->getScalarSizeInBits();
-  return llvm::VectorType::get(EltTy, { NumElts, true });
+  return llvm::ScalableVectorType::get(EltTy, NumElts);
 }
 
 // Reinterpret the input predicate so that it can be used to correctly isolate
@@ -7640,8 +7642,8 @@ Value *CodeGenFunction::EmitSVEGatherLoad(SVETypeFlags TypeFlags,
                                           SmallVectorImpl<Value *> &Ops,
                                           unsigned IntID) {
   auto *ResultTy = getSVEType(TypeFlags);
-  auto *OverloadedTy = llvm::VectorType::get(SVEBuiltinMemEltTy(TypeFlags),
-                                             ResultTy->getElementCount());
+  auto *OverloadedTy =
+      llvm::ScalableVectorType::get(SVEBuiltinMemEltTy(TypeFlags), ResultTy);
 
   // At the ACLE level there's only one predicate type, svbool_t, which is
   // mapped to <n x 16 x i1>. However, this might be incompatible with the
@@ -7692,8 +7694,8 @@ Value *CodeGenFunction::EmitSVEScatterStore(SVETypeFlags TypeFlags,
                                             SmallVectorImpl<Value *> &Ops,
                                             unsigned IntID) {
   auto *SrcDataTy = getSVEType(TypeFlags);
-  auto *OverloadedTy = llvm::VectorType::get(SVEBuiltinMemEltTy(TypeFlags),
-                                             SrcDataTy->getElementCount());
+  auto *OverloadedTy =
+      llvm::ScalableVectorType::get(SVEBuiltinMemEltTy(TypeFlags), SrcDataTy);
 
   // In ACLE the source data is passed in the last argument, whereas in LLVM IR
   // it's the first argument. Move it accordingly.
@@ -7748,7 +7750,7 @@ Value *CodeGenFunction::EmitSVEPrefetchLoad(SVETypeFlags TypeFlags,
                                             unsigned BuiltinID) {
   auto *MemEltTy = SVEBuiltinMemEltTy(TypeFlags);
   auto *VectorTy = getSVEVectorForElementType(MemEltTy);
-  auto *MemoryTy = llvm::VectorType::get(MemEltTy, VectorTy->getElementCount());
+  auto *MemoryTy = llvm::ScalableVectorType::get(MemEltTy, VectorTy);
 
   Value *Predicate = EmitSVEPredicateCast(Ops[0], MemoryTy);
   Value *BasePtr = Ops[1];
@@ -7778,8 +7780,8 @@ Value *CodeGenFunction::EmitSVEMaskedLoad(const CallExpr *E,
 
   // The vector type that is returned may be different from the
   // eventual type loaded from memory.
-  auto VectorTy = cast<llvm::VectorType>(ReturnTy);
-  auto MemoryTy = llvm::VectorType::get(MemEltTy, VectorTy->getElementCount());
+  auto VectorTy = cast<llvm::ScalableVectorType>(ReturnTy);
+  auto MemoryTy = llvm::ScalableVectorType::get(MemEltTy, VectorTy);
 
   Value *Predicate = EmitSVEPredicateCast(Ops[0], MemoryTy);
   Value *BasePtr = Builder.CreateBitCast(Ops[1], MemoryTy->getPointerTo());
@@ -7803,8 +7805,8 @@ Value *CodeGenFunction::EmitSVEMaskedStore(const CallExpr *E,
 
   // The vector type that is stored may be different from the
   // eventual type stored to memory.
-  auto VectorTy = cast<llvm::VectorType>(Ops.back()->getType());
-  auto MemoryTy = llvm::VectorType::get(MemEltTy, VectorTy->getElementCount());
+  auto VectorTy = cast<llvm::ScalableVectorType>(Ops.back()->getType());
+  auto MemoryTy = llvm::ScalableVectorType::get(MemEltTy, VectorTy);
 
   Value *Predicate = EmitSVEPredicateCast(Ops[0], MemoryTy);
   Value *BasePtr = Builder.CreateBitCast(Ops[1], MemoryTy->getPointerTo());
index 53809b6..1af83b8 100644 (file)
@@ -3911,8 +3911,8 @@ public:
   SmallVector<llvm::Type *, 2> getSVEOverloadTypes(SVETypeFlags TypeFlags,
                                                    ArrayRef<llvm::Value *> Ops);
   llvm::Type *getEltType(SVETypeFlags TypeFlags);
-  llvm::VectorType *getSVEType(const SVETypeFlags &TypeFlags);
-  llvm::VectorType *getSVEPredType(SVETypeFlags TypeFlags);
+  llvm::ScalableVectorType *getSVEType(const SVETypeFlags &TypeFlags);
+  llvm::ScalableVectorType *getSVEPredType(SVETypeFlags TypeFlags);
   llvm::Value *EmitSVEDupX(llvm::Value *Scalar);
   llvm::Value *EmitSVEPredicateCast(llvm::Value *Pred, llvm::VectorType *VTy);
   llvm::Value *EmitSVEGatherLoad(SVETypeFlags TypeFlags,