[SCEV] Make scalable size representation more explicit
authorNikita Popov <npopov@redhat.com>
Wed, 22 Feb 2023 16:05:33 +0000 (17:05 +0100)
committerNikita Popov <npopov@redhat.com>
Mon, 27 Feb 2023 09:57:53 +0000 (10:57 +0100)
Represent scalable type sizes using C * vscale, where vscale is
the vscale constant expression. This exposes a bit more information
to SCEV, because the vscale multiplier is explicitly modeled in SCEV
(rather than part of the sizeof expression).

This is mainly intended as an alternative to D143642.

Differential Revision: https://reviews.llvm.org/D144624

llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/scalable-vector.ll

index 51b67d1..548c6c5 100644 (file)
@@ -660,10 +660,8 @@ public:
     return getConstant(Ty, -1, /*isSigned=*/true);
   }
 
-  /// Return an expression for sizeof ScalableTy that is type IntTy, where
-  /// ScalableTy is a scalable vector type.
-  const SCEV *getSizeOfScalableVectorExpr(Type *IntTy,
-                                          ScalableVectorType *ScalableTy);
+  /// Return an expression for a TypeSize.
+  const SCEV *getSizeOfExpr(Type *IntTy, TypeSize Size);
 
   /// Return an expression for the alloc size of AllocTy that is type IntTy
   const SCEV *getSizeOfExpr(Type *IntTy, Type *AllocTy);
index 9245cf1..1b14d74 100644 (file)
@@ -579,14 +579,8 @@ class SCEVUnknown final : public SCEV, private CallbackVH {
 public:
   Value *getValue() const { return getValPtr(); }
 
-  /// @{
-  /// Test whether this is a special constant representing a type size in a
-  /// target-independent manner, and hasn't happened to have been folded with
-  /// other operations into something unrecognizable. This is mainly only
-  /// useful for pretty-printing and other situations where it isn't
-  /// absolutely required for these to succeed.
-  bool isSizeOf(Type *&AllocTy) const;
-  /// @}
+  /// Check whether this represents vscale.
+  bool isVScale() const;
 
   Type *getType() const { return getValPtr()->getType(); }
 
index 96cc518..a820879 100644 (file)
@@ -368,9 +368,8 @@ void SCEV::print(raw_ostream &OS) const {
   }
   case scUnknown: {
     const SCEVUnknown *U = cast<SCEVUnknown>(this);
-    Type *AllocTy;
-    if (U->isSizeOf(AllocTy)) {
-      OS << "sizeof(" << *AllocTy << ")";
+    if (U->isVScale()) {
+      OS << "vscale";
       return;
     }
 
@@ -561,20 +560,8 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) {
   setValPtr(New);
 }
 
-bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
-  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
-    if (VCE->getOpcode() == Instruction::PtrToInt)
-      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
-        if (CE->getOpcode() == Instruction::GetElementPtr &&
-            CE->getOperand(0)->isNullValue() &&
-            CE->getNumOperands() == 2)
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
-            if (CI->isOne()) {
-              AllocTy = cast<GEPOperator>(CE)->getSourceElementType();
-              return true;
-            }
-
-  return false;
+bool SCEVUnknown::isVScale() const {
+  return match(getValue(), m_VScale());
 }
 
 //===----------------------------------------------------------------------===//
@@ -4326,33 +4313,26 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
 }
 
 const SCEV *
-ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
-                                             ScalableVectorType *ScalableTy) {
-  Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
-  Constant *One = ConstantInt::get(IntTy, 1);
-  Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
-  // Note that the expression we created is the final expression, we don't
-  // want to simplify it any further Also, if we call a normal getSCEV(),
-  // we'll end up in an endless recursion. So just create an SCEVUnknown.
-  return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
+ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
+  const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
+  if (Size.isScalable()) {
+    // TODO: Why is there no ConstantExpr::getVScale()?
+    Type *SrcElemTy = ScalableVectorType::get(Type::getInt8Ty(getContext()), 1);
+    Constant *NullPtr = Constant::getNullValue(SrcElemTy->getPointerTo());
+    Constant *One = ConstantInt::get(IntTy, 1);
+    Constant *GEP = ConstantExpr::getGetElementPtr(SrcElemTy, NullPtr, One);
+    Constant *VScale = ConstantExpr::getPtrToInt(GEP, IntTy);
+    Res = getMulExpr(Res, getUnknown(VScale));
+  }
+  return Res;
 }
 
 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
-  if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
-    return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
-  // We can bypass creating a target-independent constant expression and then
-  // folding it back into a ConstantInt. This is just a compile-time
-  // optimization.
-  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
+  return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
 }
 
 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
-  if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
-    return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
-  // We can bypass creating a target-independent constant expression and then
-  // folding it back into a ConstantInt. This is just a compile-time
-  // optimization.
-  return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
+  return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
 }
 
 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
index 798a360..b5bd724 100644 (file)
@@ -5,9 +5,9 @@ define void @a(ptr %p) {
 ; CHECK-LABEL: 'a'
 ; CHECK-NEXT:  Classifying expressions for: @a
 ; CHECK-NEXT:    %1 = getelementptr <vscale x 4 x i32>, ptr null, i32 3
-; CHECK-NEXT:    --> ((3 * sizeof(<vscale x 4 x i32>)) + null) U: [0,-15) S: [-9223372036854775808,9223372036854775793)
+; CHECK-NEXT:    --> ((48 * vscale) + null) U: [0,-15) S: [-9223372036854775808,9223372036854775793)
 ; CHECK-NEXT:    %2 = getelementptr <vscale x 1 x i64>, ptr %p, i32 1
-; CHECK-NEXT:    --> (sizeof(<vscale x 1 x i64>) + %p) U: full-set S: full-set
+; CHECK-NEXT:    --> ((8 * vscale) + %p) U: full-set S: full-set
 ; CHECK-NEXT:  Determining loop execution counts for: @a
 ;
   getelementptr <vscale x 4 x i32>, ptr null, i32 3