[mlir][VectorType] Remove `numScalableDims` from the vector type
authorAndrzej Warzynski <andrzej.warzynski@arm.com>
Wed, 21 Jun 2023 12:27:13 +0000 (13:27 +0100)
committerAndrzej Warzynski <andrzej.warzynski@gmail.com>
Wed, 28 Jun 2023 12:53:45 +0000 (13:53 +0100)
This is a follow-up of https://reviews.llvm.org/D153372 in which
`numScalableDims` (single integer) was effectively replaced with
`isScalableDim` bitmask.

This change is a part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:
  * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

Differential Revision: https://reviews.llvm.org/D153412

16 files changed:
mlir/include/mlir/IR/BuiltinDialectBytecode.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/lib/AsmParser/Parser.h
mlir/lib/AsmParser/TypeParser.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/IR/BuiltinTypes.cpp

index 40e6f04..fcbb5f4 100644 (file)
@@ -275,18 +275,17 @@ def VectorType : DialectType<(type
   Array<SignedVarIntList>:$shape,
   Type:$elementType
 )> {
-  let printerPredicate = "!$_val.getNumScalableDims()";
+  let printerPredicate = "!$_val.isScalable()";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
   Array<BoolList>:$scalableDims,
-  VarInt:$numScalableDims,
   Array<SignedVarIntList>:$shape,
   Type:$elementType
 )> {
-  let printerPredicate = "$_val.getNumScalableDims()";
+  let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
 }
 }
 
index 1fd869b..f22421a 100644 (file)
@@ -306,23 +306,20 @@ public:
   /// Build from another VectorType.
   explicit Builder(VectorType other)
       : shape(other.getShape()), elementType(other.getElementType()),
-        numScalableDims(other.getNumScalableDims()),
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType,
           unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
-      : shape(shape), elementType(elementType),
-        numScalableDims(numScalableDims) {
+      : shape(shape), elementType(elementType) {
     if (scalableDims.empty())
       scalableDims = SmallVector<bool>(shape.size(), false);
     else
       this->scalableDims = scalableDims;
   }
 
-  Builder &setShape(ArrayRef<int64_t> newShape, unsigned newNumScalableDims = 0,
+  Builder &setShape(ArrayRef<int64_t> newShape,
                     ArrayRef<bool> newIsScalableDim = {}) {
-    numScalableDims = newNumScalableDims;
     if (newIsScalableDim.empty())
       scalableDims = SmallVector<bool>(shape.size(), false);
     else
@@ -340,8 +337,6 @@ public:
   /// Erase a dim from shape @pos.
   Builder &dropDim(unsigned pos) {
     assert(pos < shape.size() && "overflow");
-    if (pos >= shape.size() - numScalableDims)
-      numScalableDims--;
     if (storage.empty())
       storage.append(shape.begin(), shape.end());
     if (storageScalableDims.empty())
@@ -360,7 +355,7 @@ public:
   operator Type() {
     if (shape.empty())
       return elementType;
-    return VectorType::get(shape, elementType, numScalableDims, scalableDims);
+    return VectorType::get(shape, elementType, scalableDims);
   }
 
 private:
@@ -368,7 +363,6 @@ private:
   // Owning shape data for copy-on-write operations.
   SmallVector<int64_t> storage;
   Type elementType;
-  unsigned numScalableDims;
   ArrayRef<bool> scalableDims;
   // Owning scalableDims data for copy-on-write operations.
   SmallVector<bool> storageScalableDims;
index dead629..900531b 100644 (file)
@@ -1066,13 +1066,11 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
-    "unsigned":$numScalableDims,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      CArg<"unsigned", "0">:$numScalableDims,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
@@ -1082,8 +1080,7 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
         isScalableVec.resize(shape.size(), false);
         scalableDims = isScalableVec;
       }
-      return $_get(elementType.getContext(), shape, elementType,
-                   numScalableDims, scalableDims);
+      return $_get(elementType.getContext(), shape, elementType, scalableDims);
     }]>
   ];
   let extraClassDeclaration = [{
@@ -1100,7 +1097,13 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
 
     /// Returns true if the vector contains scalable dimensions.
     bool isScalable() const {
-      return getNumScalableDims() > 0;
+      return llvm::is_contained(getScalableDims(), true);
+    }
+    bool allDimsScalable() const {
+      // Treat 0-d vectors as fixed size.
+      if (getRank() == 0)
+        return false;
+      return !llvm::is_contained(getScalableDims(), false);
     }
 
     /// Get or create a new VectorType with the same shape as `this` and an
index 655412d..9704cea 100644 (file)
@@ -211,7 +211,6 @@ public:
   /// Parse a vector type.
   VectorType parseVectorType();
   ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
-                                       unsigned &numScalableDims,
                                        SmallVectorImpl<bool> &scalableDims);
   ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                        bool allowDynamic = true,
index 6eeea41..6a65dda 100644 (file)
@@ -441,8 +441,7 @@ VectorType Parser::parseVectorType() {
 
   SmallVector<int64_t, 4> dimensions;
   SmallVector<bool, 4> scalableDims;
-  unsigned numScalableDims;
-  if (parseVectorDimensionList(dimensions, numScalableDims, scalableDims))
+  if (parseVectorDimensionList(dimensions, scalableDims))
     return nullptr;
   if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
     return emitError(getToken().getLoc(),
@@ -459,16 +458,13 @@ VectorType Parser::parseVectorType() {
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
 
-  return VectorType::get(dimensions, elementType, numScalableDims,
-                         scalableDims);
+  return VectorType::get(dimensions, elementType, scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
 /// For i-th dimension, `scalableDims[i]` contains either:
 ///   * `false` for a non-scalable dimension (e.g. `4`),
 ///   * `true` for a scalable dimension (e.g. `[4]`).
-/// This method also returns the number of scalable dimensions in
-/// `numScalableDims`.
 ///
 /// vector-dim-list := (static-dim-list `x`)?
 /// static-dim-list ::= static-dim (`x` static-dim)*
@@ -476,9 +472,7 @@ VectorType Parser::parseVectorType() {
 ///
 ParseResult
 Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
-                                 unsigned &numScalableDims,
                                  SmallVectorImpl<bool> &scalableDims) {
-  numScalableDims = 0;
   // If there is a set of fixed-length dimensions, consume it
   while (getToken().is(Token::integer) || getToken().is(Token::l_square)) {
     int64_t value;
@@ -489,7 +483,6 @@ Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
     if (scalable) {
       if (!consumeIf(Token::r_square))
         return emitWrongTokenError("missing ']' closing scalable dimension");
-      numScalableDims++;
     }
     scalableDims.push_back(scalable);
     // Make sure we have an 'x' or something like 'xbf32'.
index 0449ba9..4ca5c75 100644 (file)
@@ -463,11 +463,12 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
     return {};
   if (type.getShape().empty())
     return VectorType::get({1}, elementType);
-  Type vectorType =
-      VectorType::get(type.getShape().back(), elementType,
-                      type.getNumScalableDims(), type.getScalableDims().back());
+  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+                                    type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
+  assert((type.isScalable() == type.allDimsScalable()) &&
+         "expected scalable vector with all dims scalable");
   auto shape = type.getShape();
   for (int i = shape.size() - 2; i >= 0; --i)
     vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
index 4175f8f..9901385 100644 (file)
@@ -31,21 +31,15 @@ using namespace mlir::vector;
 // Helper to reduce vector type by one rank at front.
 static VectorType reducedVectorTypeFront(VectorType tp) {
   assert((tp.getRank() > 1) && "unlowerable vector type");
-  unsigned numScalableDims = tp.getNumScalableDims();
-  if (tp.getShape().size() == numScalableDims)
-    --numScalableDims;
   return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
-                         numScalableDims);
+                         tp.getScalableDims().drop_front());
 }
 
 // Helper to reduce vector type by *all* but one rank at back.
 static VectorType reducedVectorTypeBack(VectorType tp) {
   assert((tp.getRank() > 1) && "unlowerable vector type");
-  unsigned numScalableDims = tp.getNumScalableDims();
-  if (numScalableDims > 0)
-    --numScalableDims;
   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
-                         numScalableDims);
+                         tp.getScalableDims().take_back());
 }
 
 // Helper that picks the proper sequence for inserting.
index 633f296..5e95d16 100644 (file)
@@ -123,7 +123,6 @@ static Type getI1SameShape(Type type) {
     return UnrankedTensorType::get(i1Type);
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return VectorType::get(vectorType.getShape(), i1Type,
-                           vectorType.getNumScalableDims(),
                            vectorType.getScalableDims());
   return i1Type;
 }
index cdbf45b..4af836a 100644 (file)
@@ -30,7 +30,6 @@ static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(type.getContext(), 1);
   if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
     return VectorType::get(sVectorType.getShape(), i1Type,
-                           sVectorType.getNumScalableDims(),
                            sVectorType.getScalableDims());
   return nullptr;
 }
index 1039bd2..bc8300a 100644 (file)
@@ -995,10 +995,7 @@ Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
 
   // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
   // scalable/non-scalable.
-  SmallVector<bool> scalableDims(1, isScalable);
-
-  return VectorType::get(numElements, elementType,
-                         static_cast<unsigned>(isScalable), scalableDims);
+  return VectorType::get(numElements, elementType, {isScalable});
 }
 
 Type mlir::LLVM::getVectorType(Type elementType,
@@ -1030,7 +1027,10 @@ Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
                                    "type");
   if (useLLVM)
     return LLVMScalableVectorType::get(elementType, numElements);
-  return VectorType::get(numElements, elementType, /*numScalableDims=*/1);
+
+  // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
+  // scalable/non-scalable.
+  return VectorType::get(numElements, elementType, /*scalableDims=*/true);
 }
 
 llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
index d0fcaad..a0bfd7f 100644 (file)
@@ -223,10 +223,7 @@ struct VectorizationState {
     assert(areValidScalableVecDims(scalableDims) &&
            "Permuted scalable vector dimensions are not supported");
 
-    // TODO: Extend scalable vector type to support a bit map.
-    bool numScalableDims = !scalableVecDims.empty() && scalableVecDims.back();
-    return VectorType::get(vectorShape, elementType, numScalableDims,
-                           scalableDims);
+    return VectorType::get(vectorShape, elementType, scalableDims);
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
@@ -1228,7 +1225,6 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
     if (firstMaxRankedType) {
       auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                      getElementTypeOrSelf(vecOperand.getType()),
-                                     firstMaxRankedType.getNumScalableDims(),
                                      firstMaxRankedType.getScalableDims());
       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
     } else {
@@ -1241,7 +1237,6 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
     resultTypes.push_back(
         firstMaxRankedType
             ? VectorType::get(firstMaxRankedType.getShape(), resultType,
-                              firstMaxRankedType.getNumScalableDims(),
                               firstMaxRankedType.getScalableDims())
             : resultType);
   }
index 77bd330..93ee064 100644 (file)
@@ -56,9 +56,7 @@ static bool isInvariantArg(BlockArgument arg, Block *block) {
 
 /// Constructs vector type for element type.
 static VectorType vectorType(VL vl, Type etp) {
-  unsigned numScalableDims = vl.enableVLAVectorization;
-  return VectorType::get(vl.vectorLength, etp, numScalableDims,
-                         vl.enableVLAVectorization);
+  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
 }
 
 /// Constructs vector type from a memref value.
index 7a39aa4..fc87c84 100644 (file)
@@ -1176,7 +1176,7 @@ Type Merger::inferType(ExprId e, Value src) const {
   // Inspect source type. For vector types, apply the same
   // vectorization to the destination type.
   if (auto vtp = dyn_cast<VectorType>(src.getType()))
-    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
+    return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
   return dtp;
 }
 
index 7dd05f5..c2562af 100644 (file)
@@ -345,9 +345,9 @@ LogicalResult MultiDimReductionOp::verify() {
 /// Returns the mask type expected by this operation.
 Type MultiDimReductionOp::getExpectedMaskType() {
   auto vecType = getSourceVectorType();
-  return VectorType::get(
-      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
-      vecType.getNumScalableDims(), vecType.getScalableDims());
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1),
+                         vecType.getScalableDims());
 }
 
 namespace {
@@ -484,9 +484,9 @@ void ReductionOp::print(OpAsmPrinter &p) {
 /// Returns the mask type expected by this operation.
 Type ReductionOp::getExpectedMaskType() {
   auto vecType = getSourceVectorType();
-  return VectorType::get(
-      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
-      vecType.getNumScalableDims(), vecType.getScalableDims());
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1),
+                         vecType.getScalableDims());
 }
 
 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -929,8 +929,7 @@ Type ContractionOp::getExpectedMaskType() {
   assert(!ShapedType::isDynamicShape(maskShape) &&
          "Mask shape couldn't be computed");
   // TODO: Extend the scalable vector type representation with a bit map.
-  assert(lhsType.getNumScalableDims() == 0 &&
-         rhsType.getNumScalableDims() == 0 &&
+  assert(!lhsType.isScalable() && !rhsType.isScalable() &&
          "Scalable vectors are not supported yet");
 
   return VectorType::get(maskShape,
@@ -2792,18 +2791,13 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
   if (vRHS) {
     SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                       vRHS.getScalableDims()[0]};
-    auto numScalableDims =
-        count_if(scalableDimsRes, [](bool isScalable) { return isScalable; });
     resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
-                              vLHS.getElementType(), numScalableDims,
-                              scalableDimsRes);
+                              vLHS.getElementType(), scalableDimsRes);
   } else {
     // Scalar RHS operand
     SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
-    auto numScalableDims =
-        count_if(scalableDimsRes, [](bool isScalable) { return isScalable; });
     resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
-                              numScalableDims, scalableDimsRes);
+                              scalableDimsRes);
   }
 
   if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
@@ -2867,9 +2861,9 @@ LogicalResult OuterProductOp::verify() {
 /// verification purposes. It requires the operation to be vectorized."
 Type OuterProductOp::getExpectedMaskType() {
   auto vecType = this->getResultVectorType();
-  return VectorType::get(
-      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
-      vecType.getNumScalableDims(), vecType.getScalableDims());
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1),
+                         vecType.getScalableDims());
 }
 
 //===----------------------------------------------------------------------===//
@@ -3528,8 +3522,7 @@ static VectorType inferTransferOpMaskType(VectorType vecType,
   SmallVector<bool> scalableDims =
       applyPermutationMap(invPermMap, vecType.getScalableDims());
 
-  return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims(),
-                         scalableDims);
+  return VectorType::get(maskShape, i1Type, scalableDims);
 }
 
 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -4487,9 +4480,9 @@ LogicalResult GatherOp::verify() {
 /// verification purposes. It requires the operation to be vectorized."
 Type GatherOp::getExpectedMaskType() {
   auto vecType = this->getIndexVectorType();
-  return VectorType::get(
-      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
-      vecType.getNumScalableDims(), vecType.getScalableDims());
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1),
+                         vecType.getScalableDims());
 }
 
 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
index ea42d57..abe6d88 100644 (file)
@@ -1024,7 +1024,7 @@ public:
     Value mask = rewriter.create<vector::CreateMaskOp>(
         loc,
         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
-                        vtp.getNumScalableDims()),
+                        vtp.getScalableDims()),
         b);
     if (xferOp.getMask()) {
       // Intersect the in-bounds with the mask specified as an op parameter.
index 62ef2c6..e29555f 100644 (file)
@@ -227,7 +227,6 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
 
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
-                                 unsigned numScalableDims,
                                  ArrayRef<bool> scalableDims) {
   if (!isValidElementType(elementType))
     return emitError()
@@ -239,21 +238,10 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
            << "vector types must have positive constant sizes but got "
            << shape;
 
-  if (numScalableDims > shape.size())
-    return emitError()
-           << "number of scalable dims cannot exceed the number of dims"
-           << " (" << numScalableDims << " vs " << shape.size() << ")";
-
   if (scalableDims.size() != shape.size())
     return emitError() << "number of dims must match, got "
                        << scalableDims.size() << " and " << shape.size();
 
-  auto numScale =
-      count_if(scalableDims, [](bool isScalable) { return isScalable; });
-  if (numScale != numScalableDims)
-    return emitError() << "number of scalable dims must match, explicit: "
-                       << numScalableDims << ", and bools:" << numScale;
-
   return success();
 }
 
@@ -262,17 +250,17 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
     return VectorType();
   if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
     if (auto scaledEt = et.scaleElementBitwidth(scale))
-      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
+      return VectorType::get(getShape(), scaledEt, getScalableDims());
   if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
     if (auto scaledEt = et.scaleElementBitwidth(scale))
-      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
+      return VectorType::get(getShape(), scaledEt, getScalableDims());
   return VectorType();
 }
 
 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                  Type elementType) const {
   return VectorType::get(shape.value_or(getShape()), elementType,
-                         getNumScalableDims());
+                         getScalableDims());
 }
 
 //===----------------------------------------------------------------------===//