[IR] Support scalable vectors in CastInst::CreatePointerCast
authorCullen Rhodes <cullen.rhodes@arm.com>
Wed, 2 Dec 2020 13:21:00 +0000 (13:21 +0000)
committerCullen Rhodes <cullen.rhodes@arm.com>
Wed, 9 Dec 2020 10:39:36 +0000 (10:39 +0000)
Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D92482

llvm/lib/IR/Instructions.cpp
llvm/unittests/IR/InstructionsTest.cpp

index 2827466..74a95da 100644 (file)
@@ -3022,8 +3022,8 @@ CastInst *CastInst::CreatePointerCast(Value *S, Type *Ty,
          "Invalid cast");
   assert(Ty->isVectorTy() == S->getType()->isVectorTy() && "Invalid cast");
   assert((!Ty->isVectorTy() ||
-          cast<FixedVectorType>(Ty)->getNumElements() ==
-              cast<FixedVectorType>(S->getType())->getNumElements()) &&
+          cast<VectorType>(Ty)->getElementCount() ==
+              cast<VectorType>(S->getType())->getElementCount()) &&
          "Invalid cast");
 
   if (Ty->isIntOrIntVectorTy())
@@ -3041,8 +3041,8 @@ CastInst *CastInst::CreatePointerCast(Value *S, Type *Ty,
          "Invalid cast");
   assert(Ty->isVectorTy() == S->getType()->isVectorTy() && "Invalid cast");
   assert((!Ty->isVectorTy() ||
-          cast<FixedVectorType>(Ty)->getNumElements() ==
-              cast<FixedVectorType>(S->getType())->getNumElements()) &&
+          cast<VectorType>(Ty)->getElementCount() ==
+              cast<VectorType>(S->getType())->getElementCount()) &&
          "Invalid cast");
 
   if (Ty->isIntOrIntVectorTy())
index 7a4dad2..419cddc 100644 (file)
@@ -368,11 +368,19 @@ TEST(InstructionsTest, CastInst) {
   Constant *NullV2I32Ptr = Constant::getNullValue(V2Int32PtrTy);
   auto Inst1 = CastInst::CreatePointerCast(NullV2I32Ptr, V2Int32Ty, "foo", BB);
 
+  Constant *NullVScaleV2I32Ptr = Constant::getNullValue(VScaleV2Int32PtrTy);
+  auto Inst1VScale = CastInst::CreatePointerCast(
+      NullVScaleV2I32Ptr, VScaleV2Int32Ty, "foo.vscale", BB);
+
   // Second form
   auto Inst2 = CastInst::CreatePointerCast(NullV2I32Ptr, V2Int32Ty);
+  auto Inst2VScale =
+      CastInst::CreatePointerCast(NullVScaleV2I32Ptr, VScaleV2Int32Ty);
 
   delete Inst2;
+  delete Inst2VScale;
   Inst1->eraseFromParent();
+  Inst1VScale->eraseFromParent();
   delete BB;
 }