[mlir][RFC] Add scalable dimensions to VectorType
authorJavier Setoain <javier.setoain@gmail.com>
Tue, 12 Oct 2021 13:26:01 +0000 (14:26 +0100)
committerJavier Setoain <javier.setoain@gmail.com>
Wed, 15 Dec 2021 09:31:37 +0000 (09:31 +0000)
With VectorType supporting scalable dimensions, we don't need many of
the operations currently present in ArmSVE, like mask generation and
basic arithmetic instructions. Therefore, this patch also gets
rid of those.

Having built-in scalable vector support also simplifies the lowering of
scalable vector dialects down to LLVMIR.

Scalable dimensions are indicated with the scalable dimensions
between square brackets:

        vector<[4]xf32>

Is a scalable vector of 4 single precission floating point elements.

More generally, a VectorType can have a set of fixed-length dimensions
followed by a set of scalable dimensions:

        vector<2x[4x4]xf32>

Is a vector with 2 scalable 4x4 vectors of single precission floating
point elements.

The scale of the scalable dimensions can be obtained with the Vector
operation:

        %vs = vector.vscale

This change is being discussed in the discourse RFC:

https://llvm.discourse.group/t/rfc-add-built-in-support-for-scalable-vector-types/4484

Differential Revision: https://reviews.llvm.org/D111819

33 files changed:
mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td [deleted file]
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/Parser/Parser.h
mlir/lib/Parser/TypeParser.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
mlir/test/Dialect/Arithmetic/ops.mlir
mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
mlir/test/Dialect/ArmSVE/memcpy.mlir [deleted file]
mlir/test/Dialect/ArmSVE/roundtrip.mlir
mlir/test/Dialect/Builtin/invalid.mlir
mlir/test/Dialect/Builtin/ops.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-scalable-memcpy.mlir [new file with mode: 0644]
mlir/test/Target/LLVMIR/arm-sve.mlir

index b0cfd58..5ffcf74 100644 (file)
@@ -16,7 +16,6 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
-include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td"
 
 //===----------------------------------------------------------------------===//
 // ArmSVE dialect definition
@@ -27,86 +26,18 @@ def ArmSVE_Dialect : Dialect {
   let cppNamespace = "::mlir::arm_sve";
   let summary = "Basic dialect to target Arm SVE architectures";
   let description = [{
-    This dialect contains the definitions necessary to target Arm SVE scalable
-    vector operations, including a scalable vector type and intrinsics for
-    some Arm SVE instructions.
+    This dialect contains the definitions necessary to target specific Arm SVE
+    scalable vector operations.
   }];
-  let useDefaultTypePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
-// ArmSVE type definitions
-//===----------------------------------------------------------------------===//
-
-def ArmSVE_ScalableVectorType : DialectType<ArmSVE_Dialect,
-    CPred<"$_self.isa<ScalableVectorType>()">,
-              "scalable vector type">,
-    BuildableType<"$_builder.getType<ScalableVectorType>()"> {
-  let description = [{
-    `arm_sve.vector` represents vectors that will be processed by a scalable
-    vector architecture.
-  }];
-}
-
-class ArmSVE_Type<string name> : TypeDef<ArmSVE_Dialect, name> { }
-
-def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
-  let mnemonic = "vector";
-
-  let summary = "Scalable vector type";
-
-  let description = [{
-    A type representing scalable length SIMD vectors. Unlike fixed-length SIMD
-    vectors, whose size is constant and known at compile time, scalable
-    vectors' length is constant but determined by the specific hardware at
-    run time.
-  }];
-
-  let parameters = (ins
-    ArrayRefParameter<"int64_t", "Vector shape">:$shape,
-    "Type":$elementType
-  );
-
-  let extraClassDeclaration = [{
-    bool hasStaticShape() const {
-      return llvm::none_of(getShape(), ShapedType::isDynamic);
-    }
-    int64_t getNumElements() const {
-      assert(hasStaticShape() &&
-             "cannot get element count of dynamic shaped type");
-      ArrayRef<int64_t> shape = getShape();
-      int64_t num = 1;
-      for (auto dim : shape)
-        num *= dim;
-      return num;
-    }
-  }];
-}
-
-//===----------------------------------------------------------------------===//
-// Additional LLVM type constraints
-//===----------------------------------------------------------------------===//
-def LLVMScalableVectorType :
-  Type<CPred<"$_self.isa<::mlir::LLVM::LLVMScalableVectorType>()">,
-       "LLVM dialect scalable vector type">;
-
-//===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
 
 class ArmSVE_Op<string mnemonic, list<OpTrait> traits = []> :
   Op<ArmSVE_Dialect, mnemonic, traits> {}
 
-class ArmSVE_NonSVEIntrUnaryOverloadedOp<string mnemonic,
-                                         list<OpTrait> traits =[]> :
-  LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
-                  /*string opName=*/mnemonic,
-                  /*string enumName=*/mnemonic,
-                  /*list<int> overloadedResults=*/[0],
-                  /*list<int> overloadedOperands=*/[], // defined by result overload
-                  /*list<OpTrait> traits=*/traits,
-                  /*int numResults=*/1>;
-
 class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                                     list<OpTrait> traits = []> :
   LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
@@ -117,42 +48,6 @@ class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                   /*list<OpTrait> traits=*/traits,
                   /*int numResults=*/1>;
 
-class ScalableFOp<string mnemonic, string op_description,
-                  list<OpTrait> traits = []> :
-  ArmSVE_Op<mnemonic, !listconcat(traits,
-                       [AllTypesMatch<["src1", "src2", "dst"]>])> {
-  let summary = op_description # " for scalable vectors of floats";
-  let description = [{
-    The `arm_sve.}] # mnemonic # [{` operations takes two scalable vectors and
-    returns one scalable vector with the result of the }] # op_description # [{.
-  }];
-  let arguments = (ins
-          ScalableVectorOf<[AnyFloat]>:$src1,
-          ScalableVectorOf<[AnyFloat]>:$src2
-  );
-  let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
-  let assemblyFormat =
-    "$src1 `,` $src2 attr-dict `:` type($src1)";
-}
-
-class ScalableIOp<string mnemonic, string op_description,
-                  list<OpTrait> traits = []> :
-  ArmSVE_Op<mnemonic, !listconcat(traits,
-                       [AllTypesMatch<["src1", "src2", "dst"]>])> {
-  let summary = op_description # " for scalable vectors of integers";
-  let description = [{
-    The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
-    returns one scalable vector with the result of the }] # op_description # [{.
-  }];
-  let arguments = (ins
-          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
-          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
-  );
-  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
-  let assemblyFormat =
-    "$src1 `,` $src2 attr-dict `:` type($src1)";
-}
-
 class ScalableMaskedFOp<string mnemonic, string op_description,
                         list<OpTrait> traits = []> :
   ArmSVE_Op<mnemonic, !listconcat(traits,
@@ -325,74 +220,6 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def VectorScaleOp : ArmSVE_Op<"vector_scale",
-                 [NoSideEffect]> {
-  let summary = "Load vector scale size";
-  let description = [{
-    The vector_scale op returns the scale of the scalable vectors, a positive
-    integer value that is constant at runtime but unknown at compile time.
-    The scale of the vector indicates the multiplicity of the vectors and
-    vector operations. I.e.: an !arm_sve.vector<4xi32> is equivalent to
-    vector_scale consecutive vector<4xi32>; and an operation on an
-    !arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale
-    times, once on each <4xi32> segment of the scalable vector. The vector_scale
-    op can be used to calculate the step in vector-length agnostic (VLA) loops.
-  }];
-  let results = (outs Index:$res);
-  let assemblyFormat =
-    "attr-dict `:` type($res)";
-}
-
-def ScalableLoadOp : ArmSVE_Op<"load">,
-    Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base, Index:$index)>,
-    Results<(outs ScalableVectorOf<[AnyType]>:$result)> {
-  let summary = "Load scalable vector from memory";
-  let description = [{
-    Load a slice of memory into a scalable vector.
-  }];
-  let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return base().getType().cast<MemRefType>();
-    }
-  }];
-  let assemblyFormat = "$base `[` $index `]` attr-dict `:` "
-    "type($result) `from` type($base)";
-}
-
-def ScalableStoreOp : ArmSVE_Op<"store">,
-    Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base, Index:$index,
-               ScalableVectorOf<[AnyType]>:$value)> {
-  let summary = "Store scalable vector into memory";
-  let description = [{
-    Store a scalable vector on a slice of memory.
-  }];
-  let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return base().getType().cast<MemRefType>();
-    }
-  }];
-  let assemblyFormat = "$value `,` $base `[` $index `]` attr-dict `:` "
-    "type($value) `to` type($base)";
-}
-
-def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>;
-
-def ScalableAddFOp : ScalableFOp<"addf", "addition", [Commutative]>;
-
-def ScalableSubIOp : ScalableIOp<"subi", "subtraction">;
-
-def ScalableSubFOp : ScalableFOp<"subf", "subtraction">;
-
-def ScalableMulIOp : ScalableIOp<"muli", "multiplication", [Commutative]>;
-
-def ScalableMulFOp : ScalableFOp<"mulf", "multiplication", [Commutative]>;
-
-def ScalableSDivIOp : ScalableIOp<"divi_signed", "signed division">;
-
-def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">;
-
-def ScalableDivFOp : ScalableFOp<"divf", "division">;
-
 def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
                                              [Commutative]>;
 
@@ -417,189 +244,56 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
 
 def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
-//===----------------------------------------------------------------------===//
-// ScalableCmpFOp
-//===----------------------------------------------------------------------===//
-
-def ScalableCmpFOp : ArmSVE_Op<"cmpf", [NoSideEffect, SameTypeOperands,
-    TypesMatchWith<"result type has i1 element type and same shape as operands",
-    "lhs", "result", "getI1SameShape($_self)">]> {
-  let summary = "floating-point comparison operation for scalable vectors";
-  let description = [{
-    The `arm_sve.cmpf` operation compares two scalable vectors of floating point
-    elements according to the float comparison rules and the predicate specified
-    by the respective attribute. The predicate defines the type of comparison:
-    (un)orderedness, (in)equality and signed less/greater than (or equal to) as
-    well as predicates that are always true or false. The result is a scalable
-    vector of i1 elements. Unlike `arm_sve.cmpi`, the operands are always
-    treated as signed. The u prefix indicates *unordered* comparison, not
-    unsigned comparison, so "une" means unordered not equal. For the sake of
-    readability by humans, custom assembly form for the operation uses a
-    string-typed attribute for the predicate.  The value of this attribute
-    corresponds to lower-cased name of the predicate constant, e.g., "one" means
-    "ordered not equal".  The string representation of the attribute is merely a
-    syntactic sugar and is converted to an integer attribute by the parser.
-
-    Example:
-
-    ```mlir
-    %r = arm_sve.cmpf oeq, %0, %1 : !arm_sve.vector<4xf32>
-    ```
-  }];
-  let arguments = (ins
-    Arith_CmpFPredicateAttr:$predicate,
-    ScalableVectorOf<[AnyFloat]>:$lhs,
-    ScalableVectorOf<[AnyFloat]>:$rhs // TODO: This should support a simple scalar
-  );
-  let results = (outs ScalableVectorOf<[I1]>:$result);
-
-  let builders = [
-    OpBuilder<(ins "arith::CmpFPredicate":$predicate, "Value":$lhs,
-                  "Value":$rhs), [{
-      buildScalableCmpFOp($_builder, $_state, predicate, lhs, rhs);
-    }]>];
-
-  let extraClassDeclaration = [{
-    static StringRef getPredicateAttrName() { return "predicate"; }
-    static arith::CmpFPredicate getPredicateByName(StringRef name);
-
-    arith::CmpFPredicate getPredicate() {
-      return (arith::CmpFPredicate) (*this)->getAttrOfType<IntegerAttr>(
-          getPredicateAttrName()).getInt();
-    }
-  }];
-
-  let verifier = [{ return success(); }];
-
-  let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
-}
-
-//===----------------------------------------------------------------------===//
-// ScalableCmpIOp
-//===----------------------------------------------------------------------===//
-
-def ScalableCmpIOp : ArmSVE_Op<"cmpi", [NoSideEffect, SameTypeOperands,
-    TypesMatchWith<"result type has i1 element type and same shape as operands",
-    "lhs", "result", "getI1SameShape($_self)">]> {
-  let summary = "integer comparison operation for scalable vectors";
-  let description = [{
-    The `arm_sve.cmpi` operation compares two scalable vectors of integer
-    elements according to the predicate specified by the respective attribute.
-
-    The predicate defines the type of comparison:
-
-    -   equal (mnemonic: `"eq"`; integer value: `0`)
-    -   not equal (mnemonic: `"ne"`; integer value: `1`)
-    -   signed less than (mnemonic: `"slt"`; integer value: `2`)
-    -   signed less than or equal (mnemonic: `"sle"`; integer value: `3`)
-    -   signed greater than (mnemonic: `"sgt"`; integer value: `4`)
-    -   signed greater than or equal (mnemonic: `"sge"`; integer value: `5`)
-    -   unsigned less than (mnemonic: `"ult"`; integer value: `6`)
-    -   unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`)
-    -   unsigned greater than (mnemonic: `"ugt"`; integer value: `8`)
-    -   unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`)
-
-    Example:
-
-    ```mlir
-    %r = arm_sve.cmpi uge, %0, %1 : !arm_sve.vector<4xi32>
-    ```
-  }];
-
-  let arguments = (ins
-      Arith_CmpIPredicateAttr:$predicate,
-      ScalableVectorOf<[I8, I16, I32, I64]>:$lhs,
-      ScalableVectorOf<[I8, I16, I32, I64]>:$rhs
-  );
-  let results = (outs ScalableVectorOf<[I1]>:$result);
-
-  let builders = [
-    OpBuilder<(ins "arith::CmpIPredicate":$predicate, "Value":$lhs,
-                 "Value":$rhs), [{
-      buildScalableCmpIOp($_builder, $_state, predicate, lhs, rhs);
-    }]>];
-
-  let extraClassDeclaration = [{
-    static StringRef getPredicateAttrName() { return "predicate"; }
-    static arith::CmpIPredicate getPredicateByName(StringRef name);
-
-    arith::CmpIPredicate getPredicate() {
-      return (arith::CmpIPredicate) (*this)->getAttrOfType<IntegerAttr>(
-          getPredicateAttrName()).getInt();
-    }
-  }];
-
-  let verifier = [{ return success(); }];
-
-  let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
-}
-
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def UdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"udot">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def ScalableMaskedAddIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"add">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def ScalableMaskedAddFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fadd">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def ScalableMaskedMulIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"mul">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def ScalableMaskedMulFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fmul">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def ScalableMaskedSubIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sub">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def ScalableMaskedSubFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fsub">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def ScalableMaskedSDivIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def ScalableMaskedUDivIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"udiv">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 def ScalableMaskedDivFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
-  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
-             LLVMScalableVectorType)>;
-
-def VectorScaleIntrOp:
-  ArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
+  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
 #endif // ARMSVE_OPS
index 06fbf5a..47686dd 100644 (file)
@@ -21,9 +21,6 @@
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 
-#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc"
-
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSVE/ArmSVE.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td
deleted file mode 100644 (file)
index ee40924..0000000
+++ /dev/null
@@ -1,53 +0,0 @@
-//===-- ArmSVEOpBase.td - Base op definitions for ArmSVE ---*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This is the base operation definition file for ArmSVE scalable vector types.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef ARMSVE_OP_BASE
-#define ARMSVE_OP_BASE
-
-//===----------------------------------------------------------------------===//
-// ArmSVE scalable vector type constraints
-//===----------------------------------------------------------------------===//
-
-def IsScalableVectorTypePred :
-    CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">;
-
-class ScalableVectorOf<list<Type> allowedTypes> :
-    ContainerType<AnyTypeOf<allowedTypes>, IsScalableVectorTypePred,
-          "$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()",
-          "scalable vector">;
-
-// Whether the number of elements of a scalable vector is from the given
-// `allowedLengths` list
-class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
-  And<[IsScalableVectorTypePred,
-       Or<!foreach(allowedlength, allowedLengths, CPred<
-          [{$_self.cast<::mlir::arm_sve::ScalableVectorType>().getNumElements() == }]
-          # allowedlength>)>]>;
-
-// Any scalable vector where the number of elements is from the given
-// `allowedLengths` list
-class ScalableVectorOfLength<list<int> allowedLengths> : Type<
-  IsScalableVectorOfLengthPred<allowedLengths>,
-  " of length " # !interleave(allowedLengths, "/"),
-  "::mlir::arm_sve::ScalableVectorType">;
-
-// Any scalable vector where the number of elements is from the given
-// `allowedLengths` list and the type is from the given `allowedTypes` list
-class ScalableVectorOfLengthAndType<list<int> allowedLengths,
-                                    list<Type> allowedTypes> : Type<
-  And<[ScalableVectorOf<allowedTypes>.predicate,
-       ScalableVectorOfLength<allowedLengths>.predicate]>,
-  ScalableVectorOf<allowedTypes>.summary #
-  ScalableVectorOfLength<allowedLengths>.summary,
-  "::mlir::arm_sve::ScalableVectorType">;
-
-#endif // ARMSVE_OP_BASE
index 453ac8d..f07f0df 100644 (file)
@@ -1734,7 +1734,9 @@ def LLVM_masked_compressstore
   let arguments = (ins LLVM_Type, LLVM_Type, LLVM_Type);
 }
 
-//
+/// Create a call to vscale intrinsic.
+def LLVM_vscale : LLVM_IntrOp<"vscale", [0], [], [], 1>;
+
 // Atomic operations.
 //
 
index 3d6a6fd..452955a 100644 (file)
@@ -483,10 +483,22 @@ Type getVectorElementType(Type type);
 /// Returns the element count of any LLVM-compatible vector type.
 llvm::ElementCount getVectorNumElements(Type type);
 
+/// Returns whether a vector type is scalable or not.
+bool isScalableVectorType(Type vectorType);
+
+/// Creates an LLVM dialect-compatible vector type with the given element type
+/// and length.
+Type getVectorType(Type elementType, unsigned numElements,
+                   bool isScalable = false);
+
 /// Creates an LLVM dialect-compatible type with the given element type and
 /// length.
 Type getFixedVectorType(Type elementType, unsigned numElements);
 
+/// Creates an LLVM dialect-compatible type with the given element type and
+/// length.
+Type getScalableVectorType(Type elementType, unsigned numElements);
+
 /// Returns the size of the given primitive LLVM dialect-compatible type
 /// (including vectors) in bits, for example, the size of i16 is 16 and
 /// the size of vector<4xi16> is 64. Returns 0 for non-primitive
index a05a377..32f3bb9 100644 (file)
@@ -2383,4 +2383,36 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect,
   let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
 }
 
+//===----------------------------------------------------------------------===//
+// VectorScaleOp
+//===----------------------------------------------------------------------===//
+
+// TODO: In the future, we might want to have scalable vectors with different
+//       scales for different dimensions. E.g.: vector<[16]x[16]xf32>, in
+//       which case we might need to add an index to 'vscale' to select one
+//       of them. In order to support GPUs, we might also want to differentiate
+//       between a 'global' scale, a scale that's fixed throughout the
+//       execution, and a 'local' scale that is fixed but might vary with each
+//       call to the function. For that, it might be useful to have a
+//       'vector.scale.global' and a 'vector.scale.local' operation.
+def VectorScaleOp : Vector_Op<"vscale",
+                 [NoSideEffect]> {
+  let summary = "Load vector scale size";
+  let description = [{
+    The `vscale` op returns the scale of the scalable vectors, a positive
+    integer value that is constant at runtime but unknown at compile-time.
+    The scale of the vector indicates the multiplicity of the vectors and
+    vector operations. For example, a `vector<[4]xi32>` is equivalent to
+    `vscale` consecutive `vector<4xi32>`; and an operation on a
+    `vector<[4]xi32>` is equivalent to performing that operation `vscale`
+    times, once on each `<4xi32>` segment of the scalable vector. The `vscale`
+    op can be used to calculate the step in vector-length agnostic (VLA) loops.
+    Right now we only support one contiguous set of scalable dimensions, all of
+    them grouped and scaled with the value returned by 'vscale'.
+  }];
+  let results = (outs Index:$res);
+  let assemblyFormat = "attr-dict";
+  let verifier = ?;
+}
+
 #endif // VECTOR_OPS
index be98b57..6f25f81 100644 (file)
@@ -315,13 +315,18 @@ class VectorType::Builder {
 public:
   /// Build from another VectorType.
   explicit Builder(VectorType other)
-      : shape(other.getShape()), elementType(other.getElementType()) {}
+      : shape(other.getShape()), elementType(other.getElementType()),
+        numScalableDims(other.getNumScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType)
-      : shape(shape), elementType(elementType) {}
-
-  Builder &setShape(ArrayRef<int64_t> newShape) {
+  Builder(ArrayRef<int64_t> shape, Type elementType,
+          unsigned numScalableDims = 0)
+      : shape(shape), elementType(elementType),
+        numScalableDims(numScalableDims) {}
+
+  Builder &setShape(ArrayRef<int64_t> newShape,
+                    unsigned newNumScalableDims = 0) {
+    numScalableDims = newNumScalableDims;
     shape = newShape;
     return *this;
   }
@@ -334,6 +339,8 @@ public:
   /// Erase a dim from shape @pos.
   Builder &dropDim(unsigned pos) {
     assert(pos < shape.size() && "overflow");
+    if (pos >= shape.size() - numScalableDims)
+      numScalableDims--;
     if (storage.empty())
       storage.append(shape.begin(), shape.end());
     storage.erase(storage.begin() + pos);
@@ -347,7 +354,7 @@ public:
   operator Type() {
     if (shape.empty())
       return elementType;
-    return VectorType::get(shape, elementType);
+    return VectorType::get(shape, elementType, numScalableDims);
   }
 
 private:
@@ -355,6 +362,7 @@ private:
   // Owning shape data for copy-on-write operations.
   SmallVector<int64_t> storage;
   Type elementType;
+  unsigned numScalableDims;
 };
 
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
index af47f1d..17e2de1 100644 (file)
@@ -892,16 +892,21 @@ def Builtin_Vector : Builtin_Type<"Vector", [
     Syntax:
 
     ```
-    vector-type ::= `vector` `<` static-dimension-list vector-element-type `>`
+    vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
     vector-element-type ::= float-type | integer-type | index-type
-
-    static-dimension-list ::= (decimal-literal `x`)*
+    vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
+    static-dim-list ::= decimal-literal (`x` decimal-literal)*
     ```
 
-    The vector type represents a SIMD style vector, used by target-specific
-    operation sets like AVX. While the most common use is for 1D vectors (e.g.
-    vector<16 x f32>) we also support multidimensional registers on targets that
-    support them (like TPUs).
+    The vector type represents a SIMD style vector used by target-specific
+    operation sets like AVX or SVE. While the most common use is for 1D
+    vectors (e.g. vector<16 x f32>) we also support multidimensional registers
+    on targets that support them (like TPUs). The dimensions of a vector type
+    can be fixed-length, scalable, or a combination of the two. The scalable
+    dimensions in a vector are indicated between square brackets ([ ]), and
+    all fixed-length dimensions, if present, must precede the set of scalable
+    dimensions. That is, a `vector<2x[4]xf32>` is valid, but `vector<[4]x2xf32>`
+    is not.
 
     Vector shapes must be positive decimal integers. 0D vectors are allowed by
     omitting the dimension: `vector<f32>`.
@@ -913,19 +918,31 @@ def Builtin_Vector : Builtin_Type<"Vector", [
     Examples:
 
     ```mlir
+    // A 2D fixed-length vector of 3x42 i32 elements.
     vector<3x42xi32>
+
+    // A 1D scalable-length vector that contains a multiple of 4 f32 elements.
+    vector<[4]xf32>
+
+    // A 2D scalable-length vector that contains a multiple of 2x8 i8 elements.
+    vector<[2x8]xf32>
+
+    // A 2D mixed fixed/scalable vector that contains 4 scalable vectors of 4 f32 elements.
+    vector<4x[4]xf32>
     ```
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    "Type":$elementType
+    "Type":$elementType,
+    "unsigned":$numScalableDims
   );
-
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType
+      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      CArg<"unsigned", "0">:$numScalableDims
     ), [{
-      return $_get(elementType.getContext(), shape, elementType);
+      return $_get(elementType.getContext(), shape, elementType,
+                   numScalableDims);
     }]>
   ];
   let extraClassDeclaration = [{
@@ -933,13 +950,18 @@ def Builtin_Vector : Builtin_Type<"Vector", [
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
 
-    /// Returns true of the given type can be used as an element of a vector
+    /// Returns true if the given type can be used as an element of a vector
     /// type. In particular, vectors can consist of integer, index, or float
     /// primitives.
     static bool isValidElementType(Type t) {
       return t.isa<IntegerType, IndexType, FloatType>();
     }
 
+    /// Returns true if the vector contains scalable dimensions.
+    bool isScalable() const {
+      return getNumScalableDims() > 0;
+    }
+
     /// Get or create a new VectorType with the same shape as `this` and an
     /// element type of bitwidth scaled by `scale`.
     /// Return null if the scaled element type cannot be represented.
index 22f4045..f1a5446 100644 (file)
@@ -216,6 +216,14 @@ def IsVectorTypePred : And<[CPred<"$_self.isa<::mlir::VectorType>()">,
 // TODO: Remove this when all ops support 0-D vectors.
 def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
 
+// Whether a type is a fixed-length VectorType.
+def IsFixedVectorTypePred : CPred<[{$_self.isa<::mlir::VectorType>() &&
+                                  !$_self.cast<VectorType>().isScalable()}]>;
+
+// Whether a type is a scalable VectorType.
+def IsScalableVectorTypePred : CPred<[{$_self.isa<::mlir::VectorType>() &&
+                                   $_self.cast<VectorType>().isScalable()}]>;
+
 // Whether a type is a TensorType.
 def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">;
 
@@ -611,6 +619,14 @@ class VectorOfAnyRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
                       "::mlir::VectorType">;
 
+class FixedVectorOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
+          "fixed-length vector", "::mlir::VectorType">;
+
+class ScalableVectorOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
+          "scalable vector", "::mlir::VectorType">;
+
 // Whether the number of elements of a vector is from the given
 // `allowedRanks` list
 class IsVectorOfRankPred<list<int> allowedRanks> :
@@ -643,6 +659,24 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
                            == }]
                          # allowedlength>)>]>;
 
+// Whether the number of elements of a fixed-length vector is from the given
+// `allowedLengths` list
+class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
+  And<[IsFixedVectorTypePred,
+       Or<!foreach(allowedlength, allowedLengths,
+                   CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()
+                           == }]
+                         # allowedlength>)>]>;
+
+// Whether the number of elements of a scalable vector is from the given
+// `allowedLengths` list
+class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
+  And<[IsScalableVectorTypePred,
+       Or<!foreach(allowedlength, allowedLengths,
+                   CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()
+                           == }]
+                         # allowedlength>)>]>;
+
 // Any vector where the number of elements is from the given
 // `allowedLengths` list
 class VectorOfLength<list<int> allowedLengths> : Type<
@@ -650,6 +684,20 @@ class VectorOfLength<list<int> allowedLengths> : Type<
   " of length " # !interleave(allowedLengths, "/"),
   "::mlir::VectorType">;
 
+// Any fixed-length vector where the number of elements is from the given
+// `allowedLengths` list
+class FixedVectorOfLength<list<int> allowedLengths> : Type<
+  IsFixedVectorOfLengthPred<allowedLengths>,
+  " of length " # !interleave(allowedLengths, "/"),
+  "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedLengths` list
+class ScalableVectorOfLength<list<int> allowedLengths> : Type<
+  IsScalableVectorOfLengthPred<allowedLengths>,
+  " of length " # !interleave(allowedLengths, "/"),
+  "::mlir::VectorType">;
+
 // Any vector where the number of elements is from the given
 // `allowedLengths` list and the type is from the given `allowedTypes`
 // list
@@ -660,10 +708,34 @@ class VectorOfLengthAndType<list<int> allowedLengths,
   VectorOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Any fixed-length vector where the number of elements is from the given
+// `allowedLengths` list and the type is from the given `allowedTypes` list
+class FixedVectorOfLengthAndType<list<int> allowedLengths,
+                                    list<Type> allowedTypes> : Type<
+  And<[FixedVectorOf<allowedTypes>.predicate,
+       FixedVectorOfLength<allowedLengths>.predicate]>,
+  FixedVectorOf<allowedTypes>.summary #
+  FixedVectorOfLength<allowedLengths>.summary,
+  "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedLengths` list and the type is from the given `allowedTypes` list
+class ScalableVectorOfLengthAndType<list<int> allowedLengths,
+                                    list<Type> allowedTypes> : Type<
+  And<[ScalableVectorOf<allowedTypes>.predicate,
+       ScalableVectorOfLength<allowedLengths>.predicate]>,
+  ScalableVectorOf<allowedTypes>.summary #
+  ScalableVectorOfLength<allowedLengths>.summary,
+  "::mlir::VectorType">;
+
 def AnyVector : VectorOf<[AnyType]>;
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
 
+def AnyFixedVector : FixedVectorOf<[AnyType]>;
+
+def AnyScalableVector : ScalableVectorOf<[AnyType]>;
+
 // Shaped types.
 
 def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
index 5175b93..3b4e7ad 100644 (file)
@@ -411,7 +411,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
     return {};
   if (type.getShape().empty())
     return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType);
+  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+                                    type.getNumScalableDims());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
   auto shape = type.getShape();
index 9b4dce4..062a544 100644 (file)
@@ -26,13 +26,21 @@ using namespace mlir::vector;
 // Helper to reduce vector type by one rank at front.
 static VectorType reducedVectorTypeFront(VectorType tp) {
   assert((tp.getRank() > 1) && "unlowerable vector type");
-  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
+  unsigned numScalableDims = tp.getNumScalableDims();
+  if (tp.getShape().size() == numScalableDims)
+    --numScalableDims;
+  return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
+                         numScalableDims);
 }
 
 // Helper to reduce vector type by *all* but one rank at back.
 static VectorType reducedVectorTypeBack(VectorType tp) {
   assert((tp.getRank() > 1) && "unlowerable vector type");
-  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
+  unsigned numScalableDims = tp.getNumScalableDims();
+  if (numScalableDims > 0)
+    --numScalableDims;
+  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
+                         numScalableDims);
 }
 
 // Helper that picks the proper sequence for inserting.
@@ -112,6 +120,10 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
 
 namespace {
 
+/// Trivial Vector to LLVM conversions
+using VectorScaleOpConversion =
+    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
+
 /// Conversion pattern for a vector.bitcast.
 class VectorBitCastOpConversion
     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
@@ -1064,7 +1076,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
            VectorExtractElementOpConversion, VectorExtractOpConversion,
            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
            VectorInsertOpConversion, VectorPrintOpConversion,
-           VectorTypeCastOpConversion,
+           VectorTypeCastOpConversion, VectorScaleOpConversion,
            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
            VectorLoadStoreConversion<vector::MaskedLoadOp,
                                      vector::MaskedLoadOpAdaptor>,
index dc4809a..aca6e4f 100644 (file)
@@ -999,7 +999,8 @@ static Type getI1SameShape(Type type) {
   if (type.isa<UnrankedTensorType>())
     return UnrankedTensorType::get(i1Type);
   if (auto vectorType = type.dyn_cast<VectorType>())
-    return VectorType::get(vectorType.getShape(), i1Type);
+    return VectorType::get(vectorType.getShape(), i1Type,
+                           vectorType.getNumScalableDims());
   return i1Type;
 }
 
index 482f0c3..b3c7904 100644 (file)
@@ -25,12 +25,6 @@ using namespace arm_sve;
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.cpp.inc"
 
 static Type getI1SameShape(Type type);
-static void buildScalableCmpIOp(OpBuilder &build, OperationState &result,
-                                arith::CmpIPredicate predicate, Value lhs,
-                                Value rhs);
-static void buildScalableCmpFOp(OpBuilder &build, OperationState &result,
-                                arith::CmpFPredicate predicate, Value lhs,
-                                Value rhs);
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
@@ -43,31 +37,6 @@ void ArmSVEDialect::initialize() {
 #define GET_OP_LIST
 #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
       >();
-  addTypes<
-#define GET_TYPEDEF_LIST
-#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
-      >();
-}
-
-//===----------------------------------------------------------------------===//
-// ScalableVectorType
-//===----------------------------------------------------------------------===//
-
-void ScalableVectorType::print(AsmPrinter &printer) const {
-  printer << "<";
-  for (int64_t dim : getShape())
-    printer << dim << 'x';
-  printer << getElementType() << '>';
-}
-
-Type ScalableVectorType::parse(AsmParser &parser) {
-  SmallVector<int64_t> dims;
-  Type eltType;
-  if (parser.parseLess() ||
-      parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
-      parser.parseType(eltType) || parser.parseGreater())
-    return {};
-  return ScalableVectorType::get(eltType.getContext(), dims, eltType);
 }
 
 //===----------------------------------------------------------------------===//
@@ -77,30 +46,8 @@ Type ScalableVectorType::parse(AsmParser &parser) {
 // Return the scalable vector of the same shape and containing i1.
 static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(type.getContext(), 1);
-  if (auto sVectorType = type.dyn_cast<ScalableVectorType>())
-    return ScalableVectorType::get(type.getContext(), sVectorType.getShape(),
-                                   i1Type);
+  if (auto sVectorType = type.dyn_cast<VectorType>())
+    return VectorType::get(sVectorType.getShape(), i1Type,
+                           sVectorType.getNumScalableDims());
   return nullptr;
 }
-
-//===----------------------------------------------------------------------===//
-// CmpFOp
-//===----------------------------------------------------------------------===//
-
-static void buildScalableCmpFOp(OpBuilder &build, OperationState &result,
-                                arith::CmpFPredicate predicate, Value lhs,
-                                Value rhs) {
-  result.addOperands({lhs, rhs});
-  result.types.push_back(getI1SameShape(lhs.getType()));
-  result.addAttribute(ScalableCmpFOp::getPredicateAttrName(),
-                      build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
-}
-
-static void buildScalableCmpIOp(OpBuilder &build, OperationState &result,
-                                arith::CmpIPredicate predicate, Value lhs,
-                                Value rhs) {
-  result.addOperands({lhs, rhs});
-  result.types.push_back(getI1SameShape(lhs.getType()));
-  result.addAttribute(ScalableCmpIOp::getPredicateAttrName(),
-                      build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
-}
index c477019..fdda398 100644 (file)
 using namespace mlir;
 using namespace mlir::arm_sve;
 
-// Extract an LLVM IR type from the LLVM IR dialect type.
-static Type unwrap(Type type) {
-  if (!type)
-    return nullptr;
-  auto *mlirContext = type.getContext();
-  if (!LLVM::isCompatibleType(type))
-    emitError(UnknownLoc::get(mlirContext),
-              "conversion resulted in a non-LLVM type");
-  return type;
-}
-
-static Optional<Type>
-convertScalableVectorTypeToLLVM(ScalableVectorType svType,
-                                LLVMTypeConverter &converter) {
-  auto elementType = unwrap(converter.convertType(svType.getElementType()));
-  if (!elementType)
-    return {};
-
-  auto sVectorType =
-      LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
-  return sVectorType;
-}
-
 template <typename OpTy>
 class ForwardOperands : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
@@ -70,22 +47,10 @@ public:
   }
 };
 
-static Optional<Value> addUnrealizedCast(OpBuilder &builder,
-                                         ScalableVectorType svType,
-                                         ValueRange inputs, Location loc) {
-  if (inputs.size() != 1 ||
-      !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
-    return Value();
-  return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
-      .getResult(0);
-}
-
 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
-using VectorScaleOpLowering =
-    OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
 using ScalableMaskedAddIOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                  ScalableMaskedAddIIntrOp>;
@@ -114,136 +79,10 @@ using ScalableMaskedDivFOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                  ScalableMaskedDivFIntrOp>;
 
-// Load operation is lowered to code that obtains a pointer to the indexed
-// element and loads from it.
-struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
-  using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto type = loadOp.getMemRefType();
-    if (!isConvertibleAndHasIdentityMaps(type))
-      return failure();
-
-    LLVMTypeConverter converter(loadOp.getContext());
-
-    auto resultType = loadOp.result().getType();
-    LLVM::LLVMPointerType llvmDataTypePtr;
-    if (resultType.isa<VectorType>()) {
-      llvmDataTypePtr =
-          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
-    } else if (resultType.isa<ScalableVectorType>()) {
-      llvmDataTypePtr = LLVM::LLVMPointerType::get(
-          convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
-                                          converter)
-              .getValue());
-    }
-    Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(),
-                                         adaptor.index(), rewriter);
-    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
-        loadOp.getLoc(), llvmDataTypePtr, dataPtr);
-    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
-    return success();
-  }
-};
-
-// Store operation is lowered to code that obtains a pointer to the indexed
-// element, and stores the given value to it.
-struct ScalableStoreOpLowering
-    : public ConvertOpToLLVMPattern<ScalableStoreOp> {
-  using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto type = storeOp.getMemRefType();
-    if (!isConvertibleAndHasIdentityMaps(type))
-      return failure();
-
-    LLVMTypeConverter converter(storeOp.getContext());
-
-    auto resultType = storeOp.value().getType();
-    LLVM::LLVMPointerType llvmDataTypePtr;
-    if (resultType.isa<VectorType>()) {
-      llvmDataTypePtr =
-          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
-    } else if (resultType.isa<ScalableVectorType>()) {
-      llvmDataTypePtr = LLVM::LLVMPointerType::get(
-          convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
-                                          converter)
-              .getValue());
-    }
-    Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(),
-                                         adaptor.index(), rewriter);
-    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
-        storeOp.getLoc(), llvmDataTypePtr, dataPtr);
-    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
-                                               bitCastedPtr);
-    return success();
-  }
-};
-
-static void
-populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
-                                         OwningRewritePatternList &patterns) {
-  // clang-format off
-  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
-               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
-               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
-               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
-               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
-               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
-               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
-               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
-               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
-              >(converter);
-  // clang-format on
-}
-
-static void
-configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
-  // clang-format off
-  target.addIllegalOp<ScalableAddIOp,
-                      ScalableAddFOp,
-                      ScalableSubIOp,
-                      ScalableSubFOp,
-                      ScalableMulIOp,
-                      ScalableMulFOp,
-                      ScalableSDivIOp,
-                      ScalableUDivIOp,
-                      ScalableDivFOp>();
-  // clang-format on
-}
-
-static void
-populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
-                                        OwningRewritePatternList &patterns) {
-  // clang-format off
-  patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
-               OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
-              >(converter);
-  // clang-format on
-}
-
-static void
-configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
-  // clang-format off
-  target.addIllegalOp<ScalableCmpFOp,
-                      ScalableCmpIOp>();
-  // clang-format on
-}
-
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   // Populate conversion patterns
-  // Remove any ArmSVE-specific types from function signatures and results.
-  populateFuncOpTypeConversionPattern(patterns, converter);
-  converter.addConversion([&converter](ScalableVectorType svType) {
-    return convertScalableVectorTypeToLLVM(svType, converter);
-  });
-  converter.addSourceMaterialization(addUnrealizedCast);
 
   // clang-format off
   patterns.add<ForwardOperands<CallOp>,
@@ -254,7 +93,6 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
-               VectorScaleOpLowering,
                ScalableMaskedAddIOpLowering,
                ScalableMaskedAddFOpLowering,
                ScalableMaskedSubIOpLowering,
@@ -264,11 +102,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ScalableMaskedSDivIOpLowering,
                ScalableMaskedUDivIOpLowering,
                ScalableMaskedDivFOpLowering>(converter);
-  patterns.add<ScalableLoadOpLowering,
-               ScalableStoreOpLowering>(converter);
   // clang-format on
-  populateBasicSVEArithmeticExportPatterns(converter, patterns);
-  populateSVEMaskGenerationExportPatterns(converter, patterns);
 }
 
 void mlir::configureArmSVELegalizeForExportTarget(
@@ -278,7 +112,6 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
-                    VectorScaleIntrOp,
                     ScalableMaskedAddIIntrOp,
                     ScalableMaskedAddFIntrOp,
                     ScalableMaskedSubIIntrOp,
@@ -292,7 +125,6 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
-                      VectorScaleOp,
                       ScalableMaskedAddIOp,
                       ScalableMaskedAddFOp,
                       ScalableMaskedSubIOp,
@@ -301,25 +133,6 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       ScalableMaskedMulFOp,
                       ScalableMaskedSDivIOp,
                       ScalableMaskedUDivIOp,
-                      ScalableMaskedDivFOp,
-                      ScalableLoadOp,
-                      ScalableStoreOp>();
+                      ScalableMaskedDivFOp>();
   // clang-format on
-  auto hasScalableVectorType = [](TypeRange types) {
-    for (Type type : types)
-      if (type.isa<arm_sve::ScalableVectorType>())
-        return true;
-    return false;
-  };
-  target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
-    return !hasScalableVectorType(op.getType().getInputs()) &&
-           !hasScalableVectorType(op.getType().getResults());
-  });
-  target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
-      [hasScalableVectorType](Operation *op) {
-        return !hasScalableVectorType(op->getOperandTypes()) &&
-               !hasScalableVectorType(op->getResultTypes());
-      });
-  configureBasicSVEArithmeticLegalizations(target);
-  configureSVEMaskGenerationLegalizations(target);
 }
index 1688fa4..5b64a29 100644 (file)
@@ -155,12 +155,14 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
     return parser.emitError(trailingTypeLoc,
                             "expected LLVM dialect-compatible type");
   if (LLVM::isCompatibleVectorType(type)) {
-    if (type.isa<LLVM::LLVMScalableVectorType>()) {
-      resultType = LLVM::LLVMScalableVectorType::get(
-          resultType, LLVM::getVectorNumElements(type).getKnownMinValue());
+    if (LLVM::isScalableVectorType(type)) {
+      resultType = LLVM::getVectorType(
+          resultType, LLVM::getVectorNumElements(type).getKnownMinValue(),
+          /*isScalable=*/true);
     } else {
-      resultType = LLVM::getFixedVectorType(
-          resultType, LLVM::getVectorNumElements(type).getFixedValue());
+      resultType = LLVM::getVectorType(
+          resultType, LLVM::getVectorNumElements(type).getFixedValue(),
+          /*isScalable=*/false);
     }
   }
 
index 87dc3b8..b129ed9 100644 (file)
@@ -775,7 +775,12 @@ Type mlir::LLVM::getVectorElementType(Type type) {
 
 llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
   return llvm::TypeSwitch<Type, llvm::ElementCount>(type)
-      .Case<LLVMFixedVectorType, VectorType>([](auto ty) {
+      .Case([](VectorType ty) {
+        if (ty.isScalable())
+          return llvm::ElementCount::getScalable(ty.getNumElements());
+        return llvm::ElementCount::getFixed(ty.getNumElements());
+      })
+      .Case([](LLVMFixedVectorType ty) {
         return llvm::ElementCount::getFixed(ty.getNumElements());
       })
       .Case([](LLVMScalableVectorType ty) {
@@ -786,6 +791,31 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
       });
 }
 
+bool mlir::LLVM::isScalableVectorType(Type vectorType) {
+  assert(
+      (vectorType
+           .isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>()) &&
+      "expected LLVM-compatible vector type");
+  return !vectorType.isa<LLVMFixedVectorType>() &&
+         (vectorType.isa<LLVMScalableVectorType>() ||
+          vectorType.cast<VectorType>().isScalable());
+}
+
+Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
+                               bool isScalable) {
+  bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
+  bool useBuiltIn = VectorType::isValidElementType(elementType);
+  (void)useBuiltIn;
+  assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
+                                   "to be either builtin or LLVM dialect type");
+  if (useLLVM) {
+    if (isScalable)
+      return LLVMScalableVectorType::get(elementType, numElements);
+    return LLVMFixedVectorType::get(elementType, numElements);
+  }
+  return VectorType::get(numElements, elementType, (unsigned)isScalable);
+}
+
 Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
   bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
   bool useBuiltIn = VectorType::isValidElementType(elementType);
@@ -797,6 +827,18 @@ Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
   return VectorType::get(numElements, elementType);
 }
 
+Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
+  bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType);
+  bool useBuiltIn = VectorType::isValidElementType(elementType);
+  (void)useBuiltIn;
+  assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector "
+                                   "type to be either builtin or LLVM dialect "
+                                   "type");
+  if (useLLVM)
+    return LLVMScalableVectorType::get(elementType, numElements);
+  return VectorType::get(numElements, elementType, /*numScalableDims=*/1);
+}
+
 llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
   assert(isCompatibleType(type) &&
          "expected a type compatible with the LLVM dialect");
index 1d045b2..2283174 100644 (file)
@@ -515,7 +515,8 @@ static Type getI1SameShape(Type type) {
   if (type.isa<UnrankedTensorType>())
     return UnrankedTensorType::get(i1Type);
   if (auto vectorType = type.dyn_cast<VectorType>())
-    return VectorType::get(vectorType.getShape(), i1Type);
+    return VectorType::get(vectorType.getShape(), i1Type,
+                           vectorType.getNumScalableDims());
   return i1Type;
 }
 
index 5d5f896..fe214f2 100644 (file)
@@ -1954,8 +1954,19 @@ void AsmPrinter::Impl::printType(Type type) {
       })
       .Case<VectorType>([&](VectorType vectorTy) {
         os << "vector<";
-        for (int64_t dim : vectorTy.getShape())
-          os << dim << 'x';
+        auto vShape = vectorTy.getShape();
+        unsigned lastDim = vShape.size();
+        unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
+        unsigned dimIdx = 0;
+        for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
+          os << vShape[dimIdx] << 'x';
+        if (vectorTy.isScalable()) {
+          os << '[';
+          unsigned secondToLastDim = lastDim - 1;
+          for (; dimIdx < secondToLastDim; dimIdx++)
+            os << vShape[dimIdx] << 'x';
+          os << vShape[dimIdx] << "]x";
+        }
         printType(vectorTy.getElementType());
         os << '>';
       })
index 38c7e16..0b2970b 100644 (file)
@@ -1165,8 +1165,9 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
   else if (inType.isa<UnrankedTensorType>())
     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
-  else if (inType.isa<VectorType>())
-    newArrayType = VectorType::get(inType.getShape(), newElementType);
+  else if (auto vType = inType.dyn_cast<VectorType>())
+    newArrayType = VectorType::get(vType.getShape(), newElementType,
+                                   vType.getNumScalableDims());
   else
     assert(newArrayType && "Unhandled tensor type");
 
index 8e408e4..e965afb 100644 (file)
@@ -294,8 +294,8 @@ ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
   if (isa<TensorType>())
     return RankedTensorType::get(shape, elementType);
 
-  if (isa<VectorType>())
-    return VectorType::get(shape, elementType);
+  if (auto vecTy = dyn_cast<VectorType>())
+    return VectorType::get(shape, elementType, vecTy.getNumScalableDims());
 
   llvm_unreachable("Unhandled ShapedType clone case");
 }
@@ -317,8 +317,8 @@ ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
   if (isa<TensorType>())
     return RankedTensorType::get(shape, getElementType());
 
-  if (isa<VectorType>())
-    return VectorType::get(shape, getElementType());
+  if (auto vecTy = dyn_cast<VectorType>())
+    return VectorType::get(shape, getElementType(), vecTy.getNumScalableDims());
 
   llvm_unreachable("Unhandled ShapedType clone case");
 }
@@ -340,8 +340,8 @@ ShapedType ShapedType::clone(Type elementType) {
     return UnrankedTensorType::get(elementType);
   }
 
-  if (isa<VectorType>())
-    return VectorType::get(getShape(), elementType);
+  if (auto vecTy = dyn_cast<VectorType>())
+    return VectorType::get(getShape(), elementType, vecTy.getNumScalableDims());
 
   llvm_unreachable("Unhandled ShapedType clone hit");
 }
@@ -441,7 +441,8 @@ bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
 //===----------------------------------------------------------------------===//
 
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
-                                 ArrayRef<int64_t> shape, Type elementType) {
+                                 ArrayRef<int64_t> shape, Type elementType,
+                                 unsigned numScalableDims) {
   if (!isValidElementType(elementType))
     return emitError()
            << "vector elements must be int/index/float type but got "
@@ -460,10 +461,10 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
     return VectorType();
   if (auto et = getElementType().dyn_cast<IntegerType>())
     if (auto scaledEt = et.scaleElementBitwidth(scale))
-      return VectorType::get(getShape(), scaledEt);
+      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
   if (auto et = getElementType().dyn_cast<FloatType>())
     if (auto scaledEt = et.scaleElementBitwidth(scale))
-      return VectorType::get(getShape(), scaledEt);
+      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
   return VectorType();
 }
 
index da0c81c..8bdf248 100644 (file)
@@ -196,8 +196,11 @@ public:
 
   /// Parse a vector type.
   VectorType parseVectorType();
+  ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
+                                       unsigned &numScalableDims);
   ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                        bool allowDynamic = true);
+  ParseResult parseIntegerInDimensionList(int64_t &value);
   ParseResult parseXInDimensionList();
 
   /// Parse strided layout specification.
index 1f505eb..5ff6e41 100644 (file)
@@ -13,6 +13,7 @@
 #include "Parser.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TensorEncoding.h"
 
 using namespace mlir;
@@ -442,8 +443,9 @@ Type Parser::parseTupleType() {
 
 /// Parse a vector type.
 ///
-///   vector-type ::= `vector` `<` static-dimension-list type `>`
-///   static-dimension-list ::= (decimal-literal `x`)*
+/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
+/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
+/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
 ///
 VectorType Parser::parseVectorType() {
   consumeToken(Token::kw_vector);
@@ -452,7 +454,8 @@ VectorType Parser::parseVectorType() {
     return nullptr;
 
   SmallVector<int64_t, 4> dimensions;
-  if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
+  unsigned numScalableDims;
+  if (parseVectorDimensionList(dimensions, numScalableDims))
     return nullptr;
   if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
     return emitError(getToken().getLoc(),
@@ -464,11 +467,59 @@ VectorType Parser::parseVectorType() {
   auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
+
   if (!VectorType::isValidElementType(elementType))
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
 
-  return VectorType::get(dimensions, elementType);
+  return VectorType::get(dimensions, elementType, numScalableDims);
+}
+
+/// Parse a dimension list in a vector type. This populates the dimension list,
+/// and returns the number of scalable dimensions in `numScalableDims`.
+///
+/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
+/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
+///
+ParseResult
+Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
+                                 unsigned &numScalableDims) {
+  numScalableDims = 0;
+  // If there is a set of fixed-length dimensions, consume it
+  while (getToken().is(Token::integer)) {
+    int64_t value;
+    if (parseIntegerInDimensionList(value))
+      return failure();
+    dimensions.push_back(value);
+    // Make sure we have an 'x' or something like 'xbf32'.
+    if (parseXInDimensionList())
+      return failure();
+  }
+  // If there is a set of scalable dimensions, consume it
+  if (consumeIf(Token::l_square)) {
+    while (getToken().is(Token::integer)) {
+      int64_t value;
+      if (parseIntegerInDimensionList(value))
+        return failure();
+      dimensions.push_back(value);
+      numScalableDims++;
+      // Check if we have reached the end of the scalable dimension list
+      if (consumeIf(Token::r_square)) {
+        // Make sure we have something like 'xbf32'.
+        if (parseXInDimensionList())
+          return failure();
+        return success();
+      }
+      // Make sure we have an 'x'
+      if (parseXInDimensionList())
+        return failure();
+    }
+    // If we make it here, we've finished parsing the dimension list
+    // without finding ']' closing the set of scalable dimensions
+    return emitError("missing ']' closing set of scalable dimensions");
+  }
+
+  return success();
 }
 
 /// Parse a dimension list of a tensor or memref type.  This populates the
@@ -490,28 +541,11 @@ Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
         return emitError("expected static shape");
       dimensions.push_back(-1);
     } else {
-      // Hexadecimal integer literals (starting with `0x`) are not allowed in
-      // aggregate type declarations.  Therefore, `0xf32` should be processed as
-      // a sequence of separate elements `0`, `x`, `f32`.
-      if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
-        // We can get here only if the token is an integer literal.  Hexadecimal
-        // integer literals can only start with `0x` (`1x` wouldn't lex as a
-        // literal, just `1` would, at which point we don't get into this
-        // branch).
-        assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
-        dimensions.push_back(0);
-        state.lex.resetPointer(getTokenSpelling().data() + 1);
-        consumeToken();
-      } else {
-        // Make sure this integer value is in bound and valid.
-        Optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
-        if (!dimension || *dimension > std::numeric_limits<int64_t>::max())
-          return emitError("invalid dimension");
-        dimensions.push_back((int64_t)dimension.getValue());
-        consumeToken(Token::integer);
-      }
+      int64_t value;
+      if (parseIntegerInDimensionList(value))
+        return failure();
+      dimensions.push_back(value);
     }
-
     // Make sure we have an 'x' or something like 'xbf32'.
     if (parseXInDimensionList())
       return failure();
@@ -520,6 +554,30 @@ Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
   return success();
 }
 
+ParseResult Parser::parseIntegerInDimensionList(int64_t &value) {
+  // Hexadecimal integer literals (starting with `0x`) are not allowed in
+  // aggregate type declarations.  Therefore, `0xf32` should be processed as
+  // a sequence of separate elements `0`, `x`, `f32`.
+  if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
+    // We can get here only if the token is an integer literal.  Hexadecimal
+    // integer literals can only start with `0x` (`1x` wouldn't lex as a
+    // literal, just `1` would, at which point we don't get into this
+    // branch).
+    assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
+    value = 0;
+    state.lex.resetPointer(getTokenSpelling().data() + 1);
+    consumeToken();
+  } else {
+    // Make sure this integer value is in bound and valid.
+    Optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
+    if (!dimension || *dimension > std::numeric_limits<int64_t>::max())
+      return emitError("invalid dimension");
+    value = (int64_t)dimension.getValue();
+    consumeToken(Token::integer);
+  }
+  return success();
+}
+
 /// Parse an 'x' token in a dimension list, handling the case where the x is
 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
 /// token.
index 5d8345d..0759340 100644 (file)
@@ -242,10 +242,14 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
       elementType = arrayTy->getElementType();
       numElements = arrayTy->getNumElements();
+    } else if (auto fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
+      elementType = fVectorTy->getElementType();
+      numElements = fVectorTy->getNumElements();
+    } else if (auto sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
+      elementType = sVectorTy->getElementType();
+      numElements = sVectorTy->getMinNumElements();
     } else {
-      auto *vectorTy = cast<llvm::FixedVectorType>(llvmType);
-      elementType = vectorTy->getElementType();
-      numElements = vectorTy->getNumElements();
+      llvm_unreachable("unrecognized constant vector type");
     }
     // Splat value is a scalar. Extract it only if the element type is not
     // another sequence type. The recursion terminates because each step removes
index aadd72a..3d67984 100644 (file)
@@ -137,6 +137,9 @@ private:
   llvm::Type *translate(VectorType type) {
     assert(LLVM::isCompatibleVectorType(type) &&
            "expected compatible with LLVM vector type");
+    if (type.isScalable())
+      return llvm::ScalableVectorType::get(translateType(type.getElementType()),
+                                           type.getNumElements());
     return llvm::FixedVectorType::get(translateType(type.getElementType()),
                                       type.getNumElements());
   }
index 6907d7b..3a79342 100644 (file)
@@ -19,6 +19,12 @@ func @test_addi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_addi_scalable_vector
+func @test_addi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.addi %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_subi
 func @test_subi(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.subi %arg0, %arg1 : i64
@@ -37,6 +43,12 @@ func @test_subi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_subi_scalable_vector
+func @test_subi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.subi %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_muli
 func @test_muli(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.muli %arg0, %arg1 : i64
@@ -55,6 +67,12 @@ func @test_muli_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_muli_scalable_vector
+func @test_muli_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.muli %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_divui
 func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.divui %arg0, %arg1 : i64
@@ -73,6 +91,12 @@ func @test_divui_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_divui_scalable_vector
+func @test_divui_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.divui %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_divsi
 func @test_divsi(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.divsi %arg0, %arg1 : i64
@@ -91,6 +115,12 @@ func @test_divsi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_divsi_scalable_vector
+func @test_divsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.divsi %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_remui
 func @test_remui(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.remui %arg0, %arg1 : i64
@@ -109,6 +139,12 @@ func @test_remui_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_remui_scalable_vector
+func @test_remui_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.remui %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_remsi
 func @test_remsi(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.remsi %arg0, %arg1 : i64
@@ -127,6 +163,12 @@ func @test_remsi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_remsi_scalable_vector
+func @test_remsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.remsi %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_andi
 func @test_andi(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.andi %arg0, %arg1 : i64
@@ -145,6 +187,12 @@ func @test_andi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_andi_scalable_vector
+func @test_andi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.andi %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_ori
 func @test_ori(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.ori %arg0, %arg1 : i64
@@ -163,6 +211,12 @@ func @test_ori_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8x
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_ori_scalable_vector
+func @test_ori_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.ori %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_xori
 func @test_xori(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.xori %arg0, %arg1 : i64
@@ -181,6 +235,12 @@ func @test_xori_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_xori_scalable_vector
+func @test_xori_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.xori %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_ceildivsi
 func @test_ceildivsi(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.ceildivsi %arg0, %arg1 : i64
@@ -199,6 +259,12 @@ func @test_ceildivsi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vec
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_ceildivsi_scalable_vector
+func @test_ceildivsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.ceildivsi %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_floordivsi
 func @test_floordivsi(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.floordivsi %arg0, %arg1 : i64
@@ -217,6 +283,12 @@ func @test_floordivsi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> ve
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_floordivsi_scalable_vector
+func @test_floordivsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.floordivsi %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_shli
 func @test_shli(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.shli %arg0, %arg1 : i64
@@ -235,6 +307,12 @@ func @test_shli_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_shli_scalable_vector
+func @test_shli_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.shli %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_shrui
 func @test_shrui(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.shrui %arg0, %arg1 : i64
@@ -253,6 +331,12 @@ func @test_shrui_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_shrui_scalable_vector
+func @test_shrui_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.shrui %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_shrsi
 func @test_shrsi(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.shrsi %arg0, %arg1 : i64
@@ -271,6 +355,12 @@ func @test_shrsi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_shrsi_scalable_vector
+func @test_shrsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0 = arith.shrsi %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_negf
 func @test_negf(%arg0 : f64) -> f64 {
   %0 = arith.negf %arg0 : f64
@@ -289,6 +379,12 @@ func @test_negf_vector(%arg0 : vector<8xf64>) -> vector<8xf64> {
   return %0 : vector<8xf64>
 }
 
+// CHECK-LABEL: test_negf_scalable_vector
+func @test_negf_scalable_vector(%arg0 : vector<[8]xf64>) -> vector<[8]xf64> {
+  %0 = arith.negf %arg0 : vector<[8]xf64>
+  return %0 : vector<[8]xf64>
+}
+
 // CHECK-LABEL: test_addf
 func @test_addf(%arg0 : f64, %arg1 : f64) -> f64 {
   %0 = arith.addf %arg0, %arg1 : f64
@@ -307,6 +403,12 @@ func @test_addf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
   return %0 : vector<8xf64>
 }
 
+// CHECK-LABEL: test_addf_scalable_vector
+func @test_addf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> {
+  %0 = arith.addf %arg0, %arg1 : vector<[8]xf64>
+  return %0 : vector<[8]xf64>
+}
+
 // CHECK-LABEL: test_subf
 func @test_subf(%arg0 : f64, %arg1 : f64) -> f64 {
   %0 = arith.subf %arg0, %arg1 : f64
@@ -325,6 +427,12 @@ func @test_subf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
   return %0 : vector<8xf64>
 }
 
+// CHECK-LABEL: test_subf_scalable_vector
+func @test_subf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> {
+  %0 = arith.subf %arg0, %arg1 : vector<[8]xf64>
+  return %0 : vector<[8]xf64>
+}
+
 // CHECK-LABEL: test_mulf
 func @test_mulf(%arg0 : f64, %arg1 : f64) -> f64 {
   %0 = arith.mulf %arg0, %arg1 : f64
@@ -343,6 +451,12 @@ func @test_mulf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
   return %0 : vector<8xf64>
 }
 
+// CHECK-LABEL: test_mulf_scalable_vector
+func @test_mulf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> {
+  %0 = arith.mulf %arg0, %arg1 : vector<[8]xf64>
+  return %0 : vector<[8]xf64>
+}
+
 // CHECK-LABEL: test_divf
 func @test_divf(%arg0 : f64, %arg1 : f64) -> f64 {
   %0 = arith.divf %arg0, %arg1 : f64
@@ -361,6 +475,12 @@ func @test_divf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
   return %0 : vector<8xf64>
 }
 
+// CHECK-LABEL: test_divf_scalable_vector
+func @test_divf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> {
+  %0 = arith.divf %arg0, %arg1 : vector<[8]xf64>
+  return %0 : vector<[8]xf64>
+}
+
 // CHECK-LABEL: test_remf
 func @test_remf(%arg0 : f64, %arg1 : f64) -> f64 {
   %0 = arith.remf %arg0, %arg1 : f64
@@ -379,6 +499,12 @@ func @test_remf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
   return %0 : vector<8xf64>
 }
 
+// CHECK-LABEL: test_remf_scalable_vector
+func @test_remf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> {
+  %0 = arith.remf %arg0, %arg1 : vector<[8]xf64>
+  return %0 : vector<[8]xf64>
+}
+
 // CHECK-LABEL: test_extui
 func @test_extui(%arg0 : i32) -> i64 {
   %0 = arith.extui %arg0 : i32 to i64
@@ -397,6 +523,12 @@ func @test_extui_vector(%arg0 : vector<8xi32>) -> vector<8xi64> {
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_extui_scalable_vector
+func @test_extui_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xi64> {
+  %0 = arith.extui %arg0 : vector<[8]xi32> to vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_extsi
 func @test_extsi(%arg0 : i32) -> i64 {
   %0 = arith.extsi %arg0 : i32 to i64
@@ -415,6 +547,12 @@ func @test_extsi_vector(%arg0 : vector<8xi32>) -> vector<8xi64> {
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_extsi_scalable_vector
+func @test_extsi_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xi64> {
+  %0 = arith.extsi %arg0 : vector<[8]xi32> to vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_extf
 func @test_extf(%arg0 : f32) -> f64 {
   %0 = arith.extf %arg0 : f32 to f64
@@ -433,6 +571,12 @@ func @test_extf_vector(%arg0 : vector<8xf32>) -> vector<8xf64> {
   return %0 : vector<8xf64>
 }
 
+// CHECK-LABEL: test_extf_scalable_vector
+func @test_extf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xf64> {
+  %0 = arith.extf %arg0 : vector<[8]xf32> to vector<[8]xf64>
+  return %0 : vector<[8]xf64>
+}
+
 // CHECK-LABEL: test_trunci
 func @test_trunci(%arg0 : i32) -> i16 {
   %0 = arith.trunci %arg0 : i32 to i16
@@ -451,6 +595,12 @@ func @test_trunci_vector(%arg0 : vector<8xi32>) -> vector<8xi16> {
   return %0 : vector<8xi16>
 }
 
+// CHECK-LABEL: test_trunci_scalable_vector
+func @test_trunci_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xi16> {
+  %0 = arith.trunci %arg0 : vector<[8]xi32> to vector<[8]xi16>
+  return %0 : vector<[8]xi16>
+}
+
 // CHECK-LABEL: test_truncf
 func @test_truncf(%arg0 : f32) -> bf16 {
   %0 = arith.truncf %arg0 : f32 to bf16
@@ -469,6 +619,12 @@ func @test_truncf_vector(%arg0 : vector<8xf32>) -> vector<8xbf16> {
   return %0 : vector<8xbf16>
 }
 
+// CHECK-LABEL: test_truncf_scalable_vector
+func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf16> {
+  %0 = arith.truncf %arg0 : vector<[8]xf32> to vector<[8]xbf16>
+  return %0 : vector<[8]xbf16>
+}
+
 // CHECK-LABEL: test_uitofp
 func @test_uitofp(%arg0 : i32) -> f32 {
   %0 = arith.uitofp %arg0 : i32 to f32
@@ -487,6 +643,12 @@ func @test_uitofp_vector(%arg0 : vector<8xi32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL: test_uitofp_scalable_vector
+func @test_uitofp_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xf32> {
+  %0 = arith.uitofp %arg0 : vector<[8]xi32> to vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // CHECK-LABEL: test_sitofp
 func @test_sitofp(%arg0 : i16) -> f64 {
   %0 = arith.sitofp %arg0 : i16 to f64
@@ -505,6 +667,12 @@ func @test_sitofp_vector(%arg0 : vector<8xi16>) -> vector<8xf64> {
   return %0 : vector<8xf64>
 }
 
+// CHECK-LABEL: test_sitofp_scalable_vector
+func @test_sitofp_scalable_vector(%arg0 : vector<[8]xi16>) -> vector<[8]xf64> {
+  %0 = arith.sitofp %arg0 : vector<[8]xi16> to vector<[8]xf64>
+  return %0 : vector<[8]xf64>
+}
+
 // CHECK-LABEL: test_fptoui
 func @test_fptoui(%arg0 : bf16) -> i8 {
   %0 = arith.fptoui %arg0 : bf16 to i8
@@ -523,6 +691,12 @@ func @test_fptoui_vector(%arg0 : vector<8xbf16>) -> vector<8xi8> {
  return %0 : vector<8xi8>
 }
 
+// CHECK-LABEL: test_fptoui_scalable_vector
+func @test_fptoui_scalable_vector(%arg0 : vector<[8]xbf16>) -> vector<[8]xi8> {
+  %0 = arith.fptoui %arg0 : vector<[8]xbf16> to vector<[8]xi8>
+  return %0 : vector<[8]xi8>
+}
+
 // CHECK-LABEL: test_fptosi
 func @test_fptosi(%arg0 : f64) -> i64 {
   %0 = arith.fptosi %arg0 : f64 to i64
@@ -541,6 +715,12 @@ func @test_fptosi_vector(%arg0 : vector<8xf64>) -> vector<8xi64> {
  return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_fptosi_scalable_vector
+func @test_fptosi_scalable_vector(%arg0 : vector<[8]xf64>) -> vector<[8]xi64> {
+  %0 = arith.fptosi %arg0 : vector<[8]xf64> to vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_index_cast0
 func @test_index_cast0(%arg0 : i32) -> index {
   %0 = arith.index_cast %arg0 : i32 to index
@@ -559,6 +739,12 @@ func @test_index_cast_vector0(%arg0 : vector<8xi32>) -> vector<8xindex> {
   return %0 : vector<8xindex>
 }
 
+// CHECK-LABEL: test_index_cast_scalable_vector0
+func @test_index_cast_scalable_vector0(%arg0 : vector<[8]xi32>) -> vector<[8]xindex> {
+  %0 = arith.index_cast %arg0 : vector<[8]xi32> to vector<[8]xindex>
+  return %0 : vector<[8]xindex>
+}
+
 // CHECK-LABEL: test_index_cast1
 func @test_index_cast1(%arg0 : index) -> i64 {
   %0 = arith.index_cast %arg0 : index to i64
@@ -577,6 +763,12 @@ func @test_index_cast_vector1(%arg0 : vector<8xindex>) -> vector<8xi64> {
   return %0 : vector<8xi64>
 }
 
+// CHECK-LABEL: test_index_cast_scalable_vector1
+func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector<[8]xi64> {
+  %0 = arith.index_cast %arg0 : vector<[8]xindex> to vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_bitcast0
 func @test_bitcast0(%arg0 : i64) -> f64 {
   %0 = arith.bitcast %arg0 : i64 to f64
@@ -595,6 +787,12 @@ func @test_bitcast_vector0(%arg0 : vector<8xi64>) -> vector<8xf64> {
   return %0 : vector<8xf64>
 }
 
+// CHECK-LABEL: test_bitcast_scalable_vector0
+func @test_bitcast_scalable_vector0(%arg0 : vector<[8]xi64>) -> vector<[8]xf64> {
+  %0 = arith.bitcast %arg0 : vector<[8]xi64> to vector<[8]xf64>
+  return %0 : vector<[8]xf64>
+}
+
 // CHECK-LABEL: test_bitcast1
 func @test_bitcast1(%arg0 : f32) -> i32 {
   %0 = arith.bitcast %arg0 : f32 to i32
@@ -613,6 +811,12 @@ func @test_bitcast_vector1(%arg0 : vector<8xf32>) -> vector<8xi32> {
   return %0 : vector<8xi32>
 }
 
+// CHECK-LABEL: test_bitcast_scalable_vector1
+func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]xi32> {
+  %0 = arith.bitcast %arg0 : vector<[8]xf32> to vector<[8]xi32>
+  return %0 : vector<[8]xi32>
+}
+
 // CHECK-LABEL: test_cmpi
 func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
   %0 = arith.cmpi ne, %arg0, %arg1 : i64
@@ -631,6 +835,12 @@ func @test_cmpi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
   return %0 : vector<8xi1>
 }
 
+// CHECK-LABEL: test_cmpi_scalable_vector
+func @test_cmpi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi1> {
+  %0 = arith.cmpi ult, %arg0, %arg1 : vector<[8]xi64>
+  return %0 : vector<[8]xi1>
+}
+
 // CHECK-LABEL: test_cmpi_vector_0d
 func @test_cmpi_vector_0d(%arg0 : vector<i64>, %arg1 : vector<i64>) -> vector<i1> {
   %0 = arith.cmpi ult, %arg0, %arg1 : vector<i64>
@@ -655,6 +865,12 @@ func @test_cmpf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
   return %0 : vector<8xi1>
 }
 
+// CHECK-LABEL: test_cmpf_scalable_vector
+func @test_cmpf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xi1> {
+  %0 = arith.cmpf ult, %arg0, %arg1 : vector<[8]xf64>
+  return %0 : vector<[8]xi1>
+}
+
 // CHECK-LABEL: test_index_cast
 func @test_index_cast(%arg0 : index) -> i64 {
   %0 = arith.index_cast %arg0 : index to i64
@@ -713,9 +929,11 @@ func @test_constant() -> () {
 
 // CHECK-LABEL: func @maximum
 func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
+               %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
                %f1: f32, %f2: f32,
                %i1: i32, %i2: i32) {
   %max_vector = arith.maxf %v1, %v2 : vector<4xf32>
+  %max_scalable_vector = arith.maxf %sv1, %sv2 : vector<[4]xf32>
   %max_float = arith.maxf %f1, %f2 : f32
   %max_signed = arith.maxsi %i1, %i2 : i32
   %max_unsigned = arith.maxui %i1, %i2 : i32
@@ -724,9 +942,11 @@ func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
 
 // CHECK-LABEL: func @minimum
 func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
+               %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
                %f1: f32, %f2: f32,
                %i1: i32, %i2: i32) {
   %min_vector = arith.minf %v1, %v2 : vector<4xf32>
+  %min_scalable_vector = arith.minf %sv1, %sv2 : vector<[4]xf32>
   %min_float = arith.minf %f1, %f2 : f32
   %min_signed = arith.minsi %i1, %i2 : i32
   %min_unsigned = arith.minui %i1, %i2 : i32
index 7fd2a63..ae495d6 100644 (file)
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-std-to-llvm | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-std-to-llvm -reconcile-unrealized-casts | mlir-opt | FileCheck %s
 
-func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
-                   %b: !arm_sve.vector<16xi8>,
-                   %c: !arm_sve.vector<4xi32>)
-    -> !arm_sve.vector<4xi32> {
+func @arm_sve_sdot(%a: vector<[16]xi8>,
+                   %b: vector<[16]xi8>,
+                   %c: vector<[4]xi32>)
+    -> vector<[4]xi32> {
   // CHECK: arm_sve.intr.sdot
   %0 = arm_sve.sdot %c, %a, %b :
-               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
-  return %0 : !arm_sve.vector<4xi32>
+               vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
 }
 
-func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
-                    %b: !arm_sve.vector<16xi8>,
-                    %c: !arm_sve.vector<4xi32>)
-    -> !arm_sve.vector<4xi32> {
+func @arm_sve_smmla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>)
+    -> vector<[4]xi32> {
   // CHECK: arm_sve.intr.smmla
   %0 = arm_sve.smmla %c, %a, %b :
-               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
-  return %0 : !arm_sve.vector<4xi32>
+               vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
 }
 
-func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
-                   %b: !arm_sve.vector<16xi8>,
-                   %c: !arm_sve.vector<4xi32>)
-    -> !arm_sve.vector<4xi32> {
+func @arm_sve_udot(%a: vector<[16]xi8>,
+                   %b: vector<[16]xi8>,
+                   %c: vector<[4]xi32>)
+    -> vector<[4]xi32> {
   // CHECK: arm_sve.intr.udot
   %0 = arm_sve.udot %c, %a, %b :
-               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
-  return %0 : !arm_sve.vector<4xi32>
+               vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
 }
 
-func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
-                    %b: !arm_sve.vector<16xi8>,
-                    %c: !arm_sve.vector<4xi32>)
-    -> !arm_sve.vector<4xi32> {
+func @arm_sve_ummla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>)
+    -> vector<[4]xi32> {
   // CHECK: arm_sve.intr.ummla
   %0 = arm_sve.ummla %c, %a, %b :
-               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
-  return %0 : !arm_sve.vector<4xi32>
+               vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
 }
 
-func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
-                     %b: !arm_sve.vector<4xi32>,
-                     %c: !arm_sve.vector<4xi32>,
-                     %d: !arm_sve.vector<4xi32>,
-                     %e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
-  // CHECK: llvm.mul {{.*}}: !llvm.vec<? x 4 x i32>
-  %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
-  // CHECK: llvm.add {{.*}}: !llvm.vec<? x 4 x i32>
-  %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
-  // CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
-  %2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32>
-  // CHECK: llvm.sdiv {{.*}}: !llvm.vec<? x 4 x i32>
-  %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32>
-  // CHECK: llvm.udiv {{.*}}: !llvm.vec<? x 4 x i32>
-  %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32>
-  return %4 : !arm_sve.vector<4xi32>
+func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
+                            %b: vector<[4]xi32>,
+                            %c: vector<[4]xi32>,
+                            %d: vector<[4]xi32>,
+                            %e: vector<[4]xi32>,
+                            %mask: vector<[4]xi1>
+                            ) -> vector<[4]xi32> {
+  // CHECK: arm_sve.intr.add{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
+  %0 = arm_sve.masked.addi %mask, %a, %b : vector<[4]xi1>,
+                                           vector<[4]xi32>
+  // CHECK: arm_sve.intr.sub{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
+  %1 = arm_sve.masked.subi %mask, %0, %c : vector<[4]xi1>,
+                                           vector<[4]xi32>
+  // CHECK: arm_sve.intr.mul{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
+  %2 = arm_sve.masked.muli %mask, %1, %d : vector<[4]xi1>,
+                                           vector<[4]xi32>
+  // CHECK: arm_sve.intr.sdiv{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
+  %3 = arm_sve.masked.divi_signed %mask, %2, %e : vector<[4]xi1>,
+                                                  vector<[4]xi32>
+  // CHECK: arm_sve.intr.udiv{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
+  %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : vector<[4]xi1>,
+                                                    vector<[4]xi32>
+  return %4 : vector<[4]xi32>
 }
 
-func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
-                     %b: !arm_sve.vector<4xf32>,
-                     %c: !arm_sve.vector<4xf32>,
-                     %d: !arm_sve.vector<4xf32>,
-                     %e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
-  // CHECK: llvm.fmul {{.*}}: !llvm.vec<? x 4 x f32>
-  %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
-  // CHECK: llvm.fadd {{.*}}: !llvm.vec<? x 4 x f32>
-  %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
-  // CHECK: llvm.fsub {{.*}}: !llvm.vec<? x 4 x f32>
-  %2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32>
-  // CHECK: llvm.fdiv {{.*}}: !llvm.vec<? x 4 x f32>
-  %3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32>
-  return %3 : !arm_sve.vector<4xf32>
+func @arm_sve_arithf_masked(%a: vector<[4]xf32>,
+                            %b: vector<[4]xf32>,
+                            %c: vector<[4]xf32>,
+                            %d: vector<[4]xf32>,
+                            %e: vector<[4]xf32>,
+                            %mask: vector<[4]xi1>
+                            ) -> vector<[4]xf32> {
+  // CHECK: arm_sve.intr.fadd{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
+  %0 = arm_sve.masked.addf %mask, %a, %b : vector<[4]xi1>,
+                                           vector<[4]xf32>
+  // CHECK: arm_sve.intr.fsub{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
+  %1 = arm_sve.masked.subf %mask, %0, %c : vector<[4]xi1>,
+                                           vector<[4]xf32>
+  // CHECK: arm_sve.intr.fmul{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
+  %2 = arm_sve.masked.mulf %mask, %1, %d : vector<[4]xi1>,
+                                           vector<[4]xf32>
+  // CHECK: arm_sve.intr.fdiv{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
+  %3 = arm_sve.masked.divf %mask, %2, %e : vector<[4]xi1>,
+                                           vector<[4]xf32>
+  return %3 : vector<[4]xf32>
 }
 
-func @arm_sve_arithi_masked(%a: !arm_sve.vector<4xi32>,
-                            %b: !arm_sve.vector<4xi32>,
-                            %c: !arm_sve.vector<4xi32>,
-                            %d: !arm_sve.vector<4xi32>,
-                            %e: !arm_sve.vector<4xi32>,
-                            %mask: !arm_sve.vector<4xi1>
-                            ) -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.intr.add{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
-  %0 = arm_sve.masked.addi %mask, %a, %b : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xi32>
-  // CHECK: arm_sve.intr.sub{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
-  %1 = arm_sve.masked.subi %mask, %0, %c : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xi32>
-  // CHECK: arm_sve.intr.mul{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
-  %2 = arm_sve.masked.muli %mask, %1, %d : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xi32>
-  // CHECK: arm_sve.intr.sdiv{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
-  %3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>,
-                                                  !arm_sve.vector<4xi32>
-  // CHECK: arm_sve.intr.udiv{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
-  %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>,
-                                                    !arm_sve.vector<4xi32>
-  return %4 : !arm_sve.vector<4xi32>
-}
-
-func @arm_sve_arithf_masked(%a: !arm_sve.vector<4xf32>,
-                            %b: !arm_sve.vector<4xf32>,
-                            %c: !arm_sve.vector<4xf32>,
-                            %d: !arm_sve.vector<4xf32>,
-                            %e: !arm_sve.vector<4xf32>,
-                            %mask: !arm_sve.vector<4xi1>
-                            ) -> !arm_sve.vector<4xf32> {
-  // CHECK: arm_sve.intr.fadd{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
-  %0 = arm_sve.masked.addf %mask, %a, %b : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xf32>
-  // CHECK: arm_sve.intr.fsub{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
-  %1 = arm_sve.masked.subf %mask, %0, %c : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xf32>
-  // CHECK: arm_sve.intr.fmul{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
-  %2 = arm_sve.masked.mulf %mask, %1, %d : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xf32>
-  // CHECK: arm_sve.intr.fdiv{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
-  %3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xf32>
-  return %3 : !arm_sve.vector<4xf32>
-}
-
-func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>,
-                        %b: !arm_sve.vector<4xf32>)
-                        -> !arm_sve.vector<4xi1> {
-  // CHECK: llvm.fcmp "oeq" {{.*}}: !llvm.vec<? x 4 x f32>
-  %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32>
-  return %0 : !arm_sve.vector<4xi1>
-}
-
-func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>,
-                        %b: !arm_sve.vector<4xi32>)
-                        -> !arm_sve.vector<4xi1> {
-  // CHECK: llvm.icmp "uge" {{.*}}: !llvm.vec<? x 4 x i32>
-  %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32>
-  return %0 : !arm_sve.vector<4xi1>
-}
-
-func @arm_sve_abs_diff(%a: !arm_sve.vector<4xi32>,
-                       %b: !arm_sve.vector<4xi32>)
-                       -> !arm_sve.vector<4xi32> {
-  // CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
-  %z = arm_sve.subi %a, %a : !arm_sve.vector<4xi32>
-  // CHECK: llvm.icmp "sge" {{.*}}: !llvm.vec<? x 4 x i32>
-  %agb = arm_sve.cmpi sge, %a, %b : !arm_sve.vector<4xi32>
-  // CHECK: llvm.icmp "slt" {{.*}}: !llvm.vec<? x 4 x i32>
-  %bga = arm_sve.cmpi slt, %a, %b : !arm_sve.vector<4xi32>
-  // CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
-  %0 = arm_sve.masked.subi %agb, %a, %b : !arm_sve.vector<4xi1>,
-                                          !arm_sve.vector<4xi32>
-  // CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
-  %1 = arm_sve.masked.subi %bga, %b, %a : !arm_sve.vector<4xi1>,
-                                          !arm_sve.vector<4xi32>
-  // CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
-  %2 = arm_sve.masked.addi %agb, %z, %0 : !arm_sve.vector<4xi1>,
-                                          !arm_sve.vector<4xi32>
-  // CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
-  %3 = arm_sve.masked.addi %bga, %2, %1 : !arm_sve.vector<4xi1>,
-                                          !arm_sve.vector<4xi32>
-  return %3 : !arm_sve.vector<4xi32>
+func @arm_sve_abs_diff(%a: vector<[4]xi32>,
+                       %b: vector<[4]xi32>)
+                       -> vector<[4]xi32> {
+  // CHECK: llvm.mlir.constant(dense<0> : vector<[4]xi32>) : vector<[4]xi32>
+  %z = arith.subi %a, %a : vector<[4]xi32>
+  // CHECK: llvm.icmp "sge" {{.*}}: vector<[4]xi32>
+  %agb = arith.cmpi sge, %a, %b : vector<[4]xi32>
+  // CHECK: llvm.icmp "slt" {{.*}}: vector<[4]xi32>
+  %bga = arith.cmpi slt, %a, %b : vector<[4]xi32>
+  // CHECK: "arm_sve.intr.sub"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
+  %0 = arm_sve.masked.subi %agb, %a, %b : vector<[4]xi1>,
+                                          vector<[4]xi32>
+  // CHECK: "arm_sve.intr.sub"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
+  %1 = arm_sve.masked.subi %bga, %b, %a : vector<[4]xi1>,
+                                          vector<[4]xi32>
+  // CHECK: "arm_sve.intr.add"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
+  %2 = arm_sve.masked.addi %agb, %z, %0 : vector<[4]xi1>,
+                                          vector<[4]xi32>
+  // CHECK: "arm_sve.intr.add"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
+  %3 = arm_sve.masked.addi %bga, %2, %1 : vector<[4]xi1>,
+                                          vector<[4]xi32>
+  return %3 : vector<[4]xi32>
 }
 
 func @get_vector_scale() -> index {
-  // CHECK: arm_sve.vscale
-  %0 = arm_sve.vector_scale : index
+  // CHECK: llvm.intr.vscale
+  %0 = vector.vscale
   return %0 : index
 }
diff --git a/mlir/test/Dialect/ArmSVE/memcpy.mlir b/mlir/test/Dialect/ArmSVE/memcpy.mlir
deleted file mode 100644 (file)
index 93a02ed..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s
-
-// CHECK: memcopy([[SRC:%arg[0-9]+]]: memref<?xf32>, [[DST:%arg[0-9]+]]
-func @memcopy(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) {
-  %c0 = arith.constant 0 : index
-  %c4 = arith.constant 4 : index
-  %vs = arm_sve.vector_scale : index
-  %step = arith.muli %c4, %vs : index
-
-  // CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref<?xf32> to !llvm.struct<(ptr<f32>
-  // CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref<?xf32> to !llvm.struct<(ptr<f32>
-  // CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}}
-  scf.for %i0 = %c0 to %size step %step {
-    // CHECK: [[SRCIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64
-    // CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr<f32>
-    // CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-    // CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
-    // CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]] : !llvm.ptr<vec<? x 4 x f32>>
-    %0 = arm_sve.load %src[%i0] : !arm_sve.vector<4xf32> from memref<?xf32>
-    // CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr<f32>
-    // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-    // CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
-    // CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]] : !llvm.ptr<vec<? x 4 x f32>>
-    arm_sve.store %0, %dst[%i0] : !arm_sve.vector<4xf32> to memref<?xf32>
-  }
-
-  return
-}
index 6f4247f..7351c36 100644 (file)
 // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
 
-func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
-                   %b: !arm_sve.vector<16xi8>,
-                   %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.sdot {{.*}}: <16xi8> to <4xi32
+func @arm_sve_sdot(%a: vector<[16]xi8>,
+                   %b: vector<[16]xi8>,
+                   %c: vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: arm_sve.sdot {{.*}}: vector<[16]xi8> to vector<[4]xi32
   %0 = arm_sve.sdot %c, %a, %b :
-             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
-  return %0 : !arm_sve.vector<4xi32>
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
 }
 
-func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
-                    %b: !arm_sve.vector<16xi8>,
-                    %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.smmla {{.*}}: <16xi8> to <4xi3
+func @arm_sve_smmla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: arm_sve.smmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
   %0 = arm_sve.smmla %c, %a, %b :
-             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
-  return %0 : !arm_sve.vector<4xi32>
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
 }
 
-func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
-                   %b: !arm_sve.vector<16xi8>,
-                   %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.udot {{.*}}: <16xi8> to <4xi32
+func @arm_sve_udot(%a: vector<[16]xi8>,
+                   %b: vector<[16]xi8>,
+                   %c: vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: arm_sve.udot {{.*}}: vector<[16]xi8> to vector<[4]xi32
   %0 = arm_sve.udot %c, %a, %b :
-             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
-  return %0 : !arm_sve.vector<4xi32>
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
 }
 
-func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
-                    %b: !arm_sve.vector<16xi8>,
-                    %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.ummla {{.*}}: <16xi8> to <4xi3
+func @arm_sve_ummla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: arm_sve.ummla {{.*}}: vector<[16]xi8> to vector<[4]xi3
   %0 = arm_sve.ummla %c, %a, %b :
-             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
-  return %0 : !arm_sve.vector<4xi32>
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
 }
 
-func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
-                     %b: !arm_sve.vector<4xi32>,
-                     %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.muli {{.*}}: !arm_sve.vector<4xi32>
-  %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
-  // CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32>
-  %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
-  return %1 : !arm_sve.vector<4xi32>
-}
-
-func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
-                     %b: !arm_sve.vector<4xf32>,
-                     %c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
-  // CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32>
-  %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
-  // CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32>
-  %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
-  return %1 : !arm_sve.vector<4xf32>
-}
-
-func @arm_sve_masked_arithi(%a: !arm_sve.vector<4xi32>,
-                            %b: !arm_sve.vector<4xi32>,
-                            %c: !arm_sve.vector<4xi32>,
-                            %d: !arm_sve.vector<4xi32>,
-                            %e: !arm_sve.vector<4xi32>,
-                            %mask: !arm_sve.vector<4xi1>)
-                            -> !arm_sve.vector<4xi32> {
-  // CHECK: arm_sve.masked.muli {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
-  %0 = arm_sve.masked.muli %mask, %a, %b : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xi32>
-  // CHECK: arm_sve.masked.addi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
-  %1 = arm_sve.masked.addi %mask, %0, %c : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xi32>
-  // CHECK: arm_sve.masked.subi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
-  %2 = arm_sve.masked.subi %mask, %1, %d : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xi32>
+func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
+                            %b: vector<[4]xi32>,
+                            %c: vector<[4]xi32>,
+                            %d: vector<[4]xi32>,
+                            %e: vector<[4]xi32>,
+                            %mask: vector<[4]xi1>)
+                            -> vector<[4]xi32> {
+  // CHECK: arm_sve.masked.muli {{.*}}: vector<[4]xi1>, vector<
+  %0 = arm_sve.masked.muli %mask, %a, %b : vector<[4]xi1>,
+                                           vector<[4]xi32>
+  // CHECK: arm_sve.masked.addi {{.*}}: vector<[4]xi1>, vector<
+  %1 = arm_sve.masked.addi %mask, %0, %c : vector<[4]xi1>,
+                                           vector<[4]xi32>
+  // CHECK: arm_sve.masked.subi {{.*}}: vector<[4]xi1>, vector<
+  %2 = arm_sve.masked.subi %mask, %1, %d : vector<[4]xi1>,
+                                           vector<[4]xi32>
   // CHECK: arm_sve.masked.divi_signed
-  %3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>,
-                                                  !arm_sve.vector<4xi32>
+  %3 = arm_sve.masked.divi_signed %mask, %2, %e : vector<[4]xi1>,
+                                                  vector<[4]xi32>
   // CHECK: arm_sve.masked.divi_unsigned
-  %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>,
-                                                    !arm_sve.vector<4xi32>
-  return %2 : !arm_sve.vector<4xi32>
-}
-
-func @arm_sve_masked_arithf(%a: !arm_sve.vector<4xf32>,
-                            %b: !arm_sve.vector<4xf32>,
-                            %c: !arm_sve.vector<4xf32>,
-                            %d: !arm_sve.vector<4xf32>,
-                            %e: !arm_sve.vector<4xf32>,
-                            %mask: !arm_sve.vector<4xi1>)
-                            -> !arm_sve.vector<4xf32> {
-  // CHECK: arm_sve.masked.mulf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
-  %0 = arm_sve.masked.mulf %mask, %a, %b : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xf32>
-  // CHECK: arm_sve.masked.addf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
-  %1 = arm_sve.masked.addf %mask, %0, %c : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xf32>
-  // CHECK: arm_sve.masked.subf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
-  %2 = arm_sve.masked.subf %mask, %1, %d : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xf32>
-  // CHECK: arm_sve.masked.divf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
-  %3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>,
-                                           !arm_sve.vector<4xf32>
-  return %3 : !arm_sve.vector<4xf32>
-}
-
-func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>,
-                        %b: !arm_sve.vector<4xf32>)
-                        -> !arm_sve.vector<4xi1> {
-  // CHECK: arm_sve.cmpf oeq, {{.*}}: !arm_sve.vector<4xf32>
-  %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32>
-  return %0 : !arm_sve.vector<4xi1>
-}
-
-func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>,
-                        %b: !arm_sve.vector<4xi32>)
-                        -> !arm_sve.vector<4xi1> {
-  // CHECK: arm_sve.cmpi uge, {{.*}}: !arm_sve.vector<4xi32>
-  %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32>
-  return %0 : !arm_sve.vector<4xi1>
-}
-
-func @arm_sve_memory(%v: !arm_sve.vector<4xi32>,
-                     %m: memref<?xi32>)
-                     -> !arm_sve.vector<4xi32> {
-  %c0 = arith.constant 0 : index
-  // CHECK: arm_sve.load {{.*}}: !arm_sve.vector<4xi32> from memref<?xi32>
-  %0 = arm_sve.load %m[%c0] : !arm_sve.vector<4xi32> from memref<?xi32>
-  // CHECK: arm_sve.store {{.*}}: !arm_sve.vector<4xi32> to memref<?xi32>
-  arm_sve.store %v, %m[%c0] : !arm_sve.vector<4xi32> to memref<?xi32>
-  return %0 : !arm_sve.vector<4xi32>
+  %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : vector<[4]xi1>,
+                                                    vector<[4]xi32>
+  return %2 : vector<[4]xi32>
 }
 
-func @get_vector_scale() -> index {
-  // CHECK: arm_sve.vector_scale : index
-  %0 = arm_sve.vector_scale : index
-  return %0 : index
+func @arm_sve_masked_arithf(%a: vector<[4]xf32>,
+                            %b: vector<[4]xf32>,
+                            %c: vector<[4]xf32>,
+                            %d: vector<[4]xf32>,
+                            %e: vector<[4]xf32>,
+                            %mask: vector<[4]xi1>)
+                            -> vector<[4]xf32> {
+  // CHECK: arm_sve.masked.mulf {{.*}}: vector<[4]xi1>, vector<
+  %0 = arm_sve.masked.mulf %mask, %a, %b : vector<[4]xi1>,
+                                           vector<[4]xf32>
+  // CHECK: arm_sve.masked.addf {{.*}}: vector<[4]xi1>, vector<
+  %1 = arm_sve.masked.addf %mask, %0, %c : vector<[4]xi1>,
+                                           vector<[4]xf32>
+  // CHECK: arm_sve.masked.subf {{.*}}: vector<[4]xi1>, vector<
+  %2 = arm_sve.masked.subf %mask, %1, %d : vector<[4]xi1>,
+                                           vector<[4]xf32>
+  // CHECK: arm_sve.masked.divf {{.*}}: vector<[4]xi1>, vector<
+  %3 = arm_sve.masked.divf %mask, %2, %e : vector<[4]xi1>,
+                                           vector<[4]xf32>
+  return %3 : vector<[4]xf32>
 }
index 5525389..f0177c4 100644 (file)
@@ -9,3 +9,11 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// VectorType
+//===----------------------------------------------------------------------===//
+
+// expected-error@+1 {{missing ']' closing set of scalable dimensions}}
+func @scalable_vector_arg(%arg0: vector<[4xf32>) { }
+
+// -----
index e1ef6f1..5e0ea41 100644 (file)
 
 // An unrealized N-1 conversion.
 %result3 = unrealized_conversion_cast %operand, %operand : !foo.type, !foo.type to !bar.tuple_type<!foo.type, !foo.type>
+
+//===----------------------------------------------------------------------===//
+// VectorType
+//===----------------------------------------------------------------------===//
+
+// A basic 1D scalable vector
+%scalable_vector_1d = "foo.op"() : () -> vector<[4]xi32>
+
+// A 2D scalable vector
+%scalable_vector_2d = "foo.op"() : () -> vector<[2x2]xf64>
+
+// A 2D scalable vector with fixed-length dimensions
+%scalable_vector_2d_mixed = "foo.op"() : () -> vector<2x[4]xbf16>
+
+// A multi-dimensional vector with mixed scalable and fixed-length dimensions
+%scalable_vector_multi_mixed = "foo.op"() : () -> vector<2x2x[4x4]xi8>
index a68d105..a77d7ba 100644 (file)
@@ -578,6 +578,25 @@ func @vector_load_and_store_1d_vector_memref(%memref : memref<200x100xvector<8xf
   return
 }
 
+// CHECK-LABEL: @vector_load_and_store_scalable_vector_memref
+func @vector_load_and_store_scalable_vector_memref(%v: vector<[4]xi32>, %m: memref<?xi32>) -> vector<[4]xi32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.load {{.*}}: memref<?xi32>, vector<[4]xi32>
+  %0 = vector.load %m[%c0] : memref<?xi32>, vector<[4]xi32>
+  // CHECK: vector.store {{.*}}: memref<?xi32>, vector<[4]xi32>
+  vector.store %v, %m[%c0] : memref<?xi32>, vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+func @vector_load_and_store_1d_scalable_vector_memref(%memref : memref<200x100xvector<8xf32>>,
+                                                      %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_load_and_store_out_of_bounds
 func @vector_load_and_store_out_of_bounds(%memref : memref<7xf32>) {
   %c0 = arith.constant 0 : index
@@ -691,3 +710,10 @@ func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
     vector<4x16xf32> to f32
   return %2 : f32
 }
+
+// CHECK-LABEL: @get_vector_scale
+func @get_vector_scale() -> index {
+  // CHECK: vector.vscale
+  %0 = vector.vscale
+  return %0 : index
+}
diff --git a/mlir/test/Dialect/Vector/vector-scalable-memcpy.mlir b/mlir/test/Dialect/Vector/vector-scalable-memcpy.mlir
new file mode 100644 (file)
index 0000000..d6d12bc
--- /dev/null
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm | mlir-opt | FileCheck %s
+
+// CHECK: vector_scalable_memcopy([[SRC:%arg[0-9]+]]: memref<?xf32>, [[DST:%arg[0-9]+]]
+func @vector_scalable_memcopy(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %vs = vector.vscale
+  %step = arith.muli %c4, %vs : index
+  // CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref<?xf32> to !llvm.struct<(ptr<f32>
+  // CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref<?xf32> to !llvm.struct<(ptr<f32>
+  // CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}}
+  scf.for %i0 = %c0 to %size step %step {
+    // CHECK: [[DATAIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64
+    // CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr<f32>
+    // CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[DATAIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+    // CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>>
+    // CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]]{{.*}}: !llvm.ptr<vector<[4]xf32>>
+    %0 = vector.load %src[%i0] : memref<?xf32>, vector<[4]xf32>
+    // CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr<f32>
+    // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[DATAIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+    // CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>>
+    // CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]]{{.*}}: !llvm.ptr<vector<[4]xf32>>
+    vector.store %0, %dst[%i0] : memref<?xf32>, vector<[4]xf32>
+  }
+
+  return
+}
index 5dcec26..07adf90 100644 (file)
 // RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
 
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_sdot
-llvm.func @arm_sve_sdot(%arg0: !llvm.vec<?x16 x i8>,
-                        %arg1: !llvm.vec<?x16 x i8>,
-                        %arg2: !llvm.vec<?x4 x i32>)
-                        -> !llvm.vec<?x4 x i32> {
+llvm.func @arm_sve_sdot(%arg0: vector<[16]xi8>,
+                        %arg1: vector<[16]xi8>,
+                        %arg2: vector<[4]xi32>)
+                        -> vector<[4]xi32> {
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4
   %0 = "arm_sve.intr.sdot"(%arg2, %arg0, %arg1) :
-    (!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-        -> !llvm.vec<?x4 x i32>
-  llvm.return %0 : !llvm.vec<?x4 x i32>
+    (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+        -> vector<[4]xi32>
+  llvm.return %0 : vector<[4]xi32>
 }
 
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_smmla
-llvm.func @arm_sve_smmla(%arg0: !llvm.vec<?x16 x i8>,
-                         %arg1: !llvm.vec<?x16 x i8>,
-                         %arg2: !llvm.vec<?x4 x i32>)
-                         -> !llvm.vec<?x4 x i32> {
+llvm.func @arm_sve_smmla(%arg0: vector<[16]xi8>,
+                         %arg1: vector<[16]xi8>,
+                         %arg2: vector<[4]xi32>)
+                         -> vector<[4]xi32> {
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4
   %0 = "arm_sve.intr.smmla"(%arg2, %arg0, %arg1) :
-    (!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-        -> !llvm.vec<?x4 x i32>
-  llvm.return %0 : !llvm.vec<?x4 x i32>
+    (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+        -> vector<[4]xi32>
+  llvm.return %0 : vector<[4]xi32>
 }
 
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_udot
-llvm.func @arm_sve_udot(%arg0: !llvm.vec<?x16 x i8>,
-                        %arg1: !llvm.vec<?x16 x i8>,
-                        %arg2: !llvm.vec<?x4 x i32>)
-                        -> !llvm.vec<?x4 x i32> {
+llvm.func @arm_sve_udot(%arg0: vector<[16]xi8>,
+                        %arg1: vector<[16]xi8>,
+                        %arg2: vector<[4]xi32>)
+                        -> vector<[4]xi32> {
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4
   %0 = "arm_sve.intr.udot"(%arg2, %arg0, %arg1) :
-    (!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-        -> !llvm.vec<?x4 x i32>
-  llvm.return %0 : !llvm.vec<?x4 x i32>
+    (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+        -> vector<[4]xi32>
+  llvm.return %0 : vector<[4]xi32>
 }
 
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_ummla
-llvm.func @arm_sve_ummla(%arg0: !llvm.vec<?x16 x i8>,
-                         %arg1: !llvm.vec<?x16 x i8>,
-                         %arg2: !llvm.vec<?x4 x i32>)
-                         -> !llvm.vec<?x4 x i32> {
+llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
+                         %arg1: vector<[16]xi8>,
+                         %arg2: vector<[4]xi32>)
+                         -> vector<[4]xi32> {
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4
   %0 = "arm_sve.intr.ummla"(%arg2, %arg0, %arg1) :
-    (!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-        -> !llvm.vec<?x4 x i32>
-  llvm.return %0 : !llvm.vec<?x4 x i32>
+    (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+        -> vector<[4]xi32>
+  llvm.return %0 : vector<[4]xi32>
 }
 
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
-llvm.func @arm_sve_arithi(%arg0: !llvm.vec<? x 4 x i32>,
-                          %arg1: !llvm.vec<? x 4 x i32>,
-                          %arg2: !llvm.vec<? x 4 x i32>)
-                          -> !llvm.vec<? x 4 x i32> {
+llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
+                          %arg1: vector<[4]xi32>,
+                          %arg2: vector<[4]xi32>)
+                          -> vector<[4]xi32> {
   // CHECK: mul <vscale x 4 x i32>
-  %0 = llvm.mul %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+  %0 = llvm.mul %arg0, %arg1 : vector<[4]xi32>
   // CHECK: add <vscale x 4 x i32>
-  %1 = llvm.add %0, %arg2 : !llvm.vec<? x 4 x i32>
-  llvm.return %1 : !llvm.vec<? x 4 x i32>
+  %1 = llvm.add %0, %arg2 : vector<[4]xi32>
+  llvm.return %1 : vector<[4]xi32>
 }
 
 // CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf
-llvm.func @arm_sve_arithf(%arg0: !llvm.vec<? x 4 x f32>,
-                          %arg1: !llvm.vec<? x 4 x f32>,
-                          %arg2: !llvm.vec<? x 4 x f32>)
-                          -> !llvm.vec<? x 4 x f32> {
+llvm.func @arm_sve_arithf(%arg0: vector<[4]xf32>,
+                          %arg1: vector<[4]xf32>,
+                          %arg2: vector<[4]xf32>)
+                          -> vector<[4]xf32> {
   // CHECK: fmul <vscale x 4 x float>
-  %0 = llvm.fmul %arg0, %arg1 : !llvm.vec<? x 4 x f32>
+  %0 = llvm.fmul %arg0, %arg1 : vector<[4]xf32>
   // CHECK: fadd <vscale x 4 x float>
-  %1 = llvm.fadd %0, %arg2 : !llvm.vec<? x 4 x f32>
-  llvm.return %1 : !llvm.vec<? x 4 x f32>
+  %1 = llvm.fadd %0, %arg2 : vector<[4]xf32>
+  llvm.return %1 : vector<[4]xf32>
 }
 
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi_masked
-llvm.func @arm_sve_arithi_masked(%arg0: !llvm.vec<? x 4 x i32>,
-                                 %arg1: !llvm.vec<? x 4 x i32>,
-                                 %arg2: !llvm.vec<? x 4 x i32>,
-                                 %arg3: !llvm.vec<? x 4 x i32>,
-                                 %arg4: !llvm.vec<? x 4 x i32>,
-                                 %arg5: !llvm.vec<? x 4 x i1>)
-                                 -> !llvm.vec<? x 4 x i32> {
+llvm.func @arm_sve_arithi_masked(%arg0: vector<[4]xi32>,
+                                 %arg1: vector<[4]xi32>,
+                                 %arg2: vector<[4]xi32>,
+                                 %arg3: vector<[4]xi32>,
+                                 %arg4: vector<[4]xi32>,
+                                 %arg5: vector<[4]xi1>)
+                                 -> vector<[4]xi32> {
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
-  %0 = "arm_sve.intr.add"(%arg5, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>,
-                                                  !llvm.vec<? x 4 x i32>,
-                                                  !llvm.vec<? x 4 x i32>)
-                                                  -> !llvm.vec<? x 4 x i32>
+  %0 = "arm_sve.intr.add"(%arg5, %arg0, %arg1) : (vector<[4]xi1>,
+                                                  vector<[4]xi32>,
+                                                  vector<[4]xi32>)
+                                                  -> vector<[4]xi32>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
-  %1 = "arm_sve.intr.sub"(%arg5, %0, %arg1) : (!llvm.vec<? x 4 x i1>,
-                                               !llvm.vec<? x 4 x i32>,
-                                               !llvm.vec<? x 4 x i32>)
-                                               -> !llvm.vec<? x 4 x i32>
+  %1 = "arm_sve.intr.sub"(%arg5, %0, %arg1) : (vector<[4]xi1>,
+                                               vector<[4]xi32>,
+                                               vector<[4]xi32>)
+                                               -> vector<[4]xi32>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32
-  %2 = "arm_sve.intr.mul"(%arg5, %1, %arg3) : (!llvm.vec<? x 4 x i1>,
-                                               !llvm.vec<? x 4 x i32>,
-                                               !llvm.vec<? x 4 x i32>)
-                                               -> !llvm.vec<? x 4 x i32>
+  %2 = "arm_sve.intr.mul"(%arg5, %1, %arg3) : (vector<[4]xi1>,
+                                               vector<[4]xi32>,
+                                               vector<[4]xi32>)
+                                               -> vector<[4]xi32>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdiv.nxv4i32
-  %3 = "arm_sve.intr.sdiv"(%arg5, %2, %arg4) : (!llvm.vec<? x 4 x i1>,
-                                               !llvm.vec<? x 4 x i32>,
-                                               !llvm.vec<? x 4 x i32>)
-                                               -> !llvm.vec<? x 4 x i32>
+  %3 = "arm_sve.intr.sdiv"(%arg5, %2, %arg4) : (vector<[4]xi1>,
+                                               vector<[4]xi32>,
+                                               vector<[4]xi32>)
+                                               -> vector<[4]xi32>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udiv.nxv4i32
-  %4 = "arm_sve.intr.udiv"(%arg5, %3, %arg4) : (!llvm.vec<? x 4 x i1>,
-                                               !llvm.vec<? x 4 x i32>,
-                                               !llvm.vec<? x 4 x i32>)
-                                               -> !llvm.vec<? x 4 x i32>
-  llvm.return %4 : !llvm.vec<? x 4 x i32>
+  %4 = "arm_sve.intr.udiv"(%arg5, %3, %arg4) : (vector<[4]xi1>,
+                                               vector<[4]xi32>,
+                                               vector<[4]xi32>)
+                                               -> vector<[4]xi32>
+  llvm.return %4 : vector<[4]xi32>
 }
 
 // CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf_masked
-llvm.func @arm_sve_arithf_masked(%arg0: !llvm.vec<? x 4 x f32>,
-                                 %arg1: !llvm.vec<? x 4 x f32>,
-                                 %arg2: !llvm.vec<? x 4 x f32>,
-                                 %arg3: !llvm.vec<? x 4 x f32>,
-                                 %arg4: !llvm.vec<? x 4 x f32>,
-                                 %arg5: !llvm.vec<? x 4 x i1>)
-                                 -> !llvm.vec<? x 4 x f32> {
+llvm.func @arm_sve_arithf_masked(%arg0: vector<[4]xf32>,
+                                 %arg1: vector<[4]xf32>,
+                                 %arg2: vector<[4]xf32>,
+                                 %arg3: vector<[4]xf32>,
+                                 %arg4: vector<[4]xf32>,
+                                 %arg5: vector<[4]xi1>)
+                                 -> vector<[4]xf32> {
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32
-  %0 = "arm_sve.intr.fadd"(%arg5, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>,
-                                                   !llvm.vec<? x 4 x f32>,
-                                                   !llvm.vec<? x 4 x f32>)
-                                                   -> !llvm.vec<? x 4 x f32>
+  %0 = "arm_sve.intr.fadd"(%arg5, %arg0, %arg1) : (vector<[4]xi1>,
+                                                   vector<[4]xf32>,
+                                                   vector<[4]xf32>)
+                                                   -> vector<[4]xf32>
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fsub.nxv4f32
-  %1 = "arm_sve.intr.fsub"(%arg5, %0, %arg2) : (!llvm.vec<? x 4 x i1>,
-                                                !llvm.vec<? x 4 x f32>,
-                                                !llvm.vec<? x 4 x f32>)
-                                                -> !llvm.vec<? x 4 x f32>
+  %1 = "arm_sve.intr.fsub"(%arg5, %0, %arg2) : (vector<[4]xi1>,
+                                                vector<[4]xf32>,
+                                                vector<[4]xf32>)
+                                                -> vector<[4]xf32>
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32
-  %2 = "arm_sve.intr.fmul"(%arg5, %1, %arg3) : (!llvm.vec<? x 4 x i1>,
-                                                !llvm.vec<? x 4 x f32>,
-                                                !llvm.vec<? x 4 x f32>)
-                                                -> !llvm.vec<? x 4 x f32>
+  %2 = "arm_sve.intr.fmul"(%arg5, %1, %arg3) : (vector<[4]xi1>,
+                                                vector<[4]xf32>,
+                                                vector<[4]xf32>)
+                                                -> vector<[4]xf32>
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fdiv.nxv4f32
-  %3 = "arm_sve.intr.fdiv"(%arg5, %2, %arg4) : (!llvm.vec<? x 4 x i1>,
-                                                !llvm.vec<? x 4 x f32>,
-                                                !llvm.vec<? x 4 x f32>)
-                                                -> !llvm.vec<? x 4 x f32>
-  llvm.return %3 : !llvm.vec<? x 4 x f32>
+  %3 = "arm_sve.intr.fdiv"(%arg5, %2, %arg4) : (vector<[4]xi1>,
+                                                vector<[4]xf32>,
+                                                vector<[4]xf32>)
+                                                -> vector<[4]xf32>
+  llvm.return %3 : vector<[4]xf32>
 }
 
 // CHECK-LABEL: define <vscale x 4 x i1> @arm_sve_mask_genf
-llvm.func @arm_sve_mask_genf(%arg0: !llvm.vec<? x 4 x f32>,
-                             %arg1: !llvm.vec<? x 4 x f32>)
-                             -> !llvm.vec<? x 4 x i1> {
+llvm.func @arm_sve_mask_genf(%arg0: vector<[4]xf32>,
+                             %arg1: vector<[4]xf32>)
+                             -> vector<[4]xi1> {
   // CHECK: fcmp oeq <vscale x 4 x float>
-  %0 = llvm.fcmp "oeq" %arg0, %arg1 : !llvm.vec<? x 4 x f32>
-  llvm.return %0 : !llvm.vec<? x 4 x i1>
+  %0 = llvm.fcmp "oeq" %arg0, %arg1 : vector<[4]xf32>
+  llvm.return %0 : vector<[4]xi1>
 }
 
 // CHECK-LABEL: define <vscale x 4 x i1> @arm_sve_mask_geni
-llvm.func @arm_sve_mask_geni(%arg0: !llvm.vec<? x 4 x i32>,
-                             %arg1: !llvm.vec<? x 4 x i32>)
-                             -> !llvm.vec<? x 4 x i1> {
+llvm.func @arm_sve_mask_geni(%arg0: vector<[4]xi32>,
+                             %arg1: vector<[4]xi32>)
+                             -> vector<[4]xi1> {
   // CHECK: icmp uge <vscale x 4 x i32>
-  %0 = llvm.icmp "uge" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
-  llvm.return %0 : !llvm.vec<? x 4 x i1>
+  %0 = llvm.icmp "uge" %arg0, %arg1 : vector<[4]xi32>
+  llvm.return %0 : vector<[4]xi1>
 }
 
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_abs_diff
-llvm.func @arm_sve_abs_diff(%arg0: !llvm.vec<? x 4 x i32>,
-                            %arg1: !llvm.vec<? x 4 x i32>)
-                            -> !llvm.vec<? x 4 x i32> {
+llvm.func @arm_sve_abs_diff(%arg0: vector<[4]xi32>,
+                            %arg1: vector<[4]xi32>)
+                            -> vector<[4]xi32> {
   // CHECK: sub <vscale x 4 x i32>
-  %0 = llvm.sub %arg0, %arg0  : !llvm.vec<? x 4 x i32>
+  %0 = llvm.sub %arg0, %arg0  : vector<[4]xi32>
   // CHECK: icmp sge <vscale x 4 x i32>
-  %1 = llvm.icmp "sge" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+  %1 = llvm.icmp "sge" %arg0, %arg1 : vector<[4]xi32>
   // CHECK: icmp slt <vscale x 4 x i32>
-  %2 = llvm.icmp "slt" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+  %2 = llvm.icmp "slt" %arg0, %arg1 : vector<[4]xi32>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
-  %3 = "arm_sve.intr.sub"(%1, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>,
-                                               !llvm.vec<? x 4 x i32>,
-                                               !llvm.vec<? x 4 x i32>)
-                                               -> !llvm.vec<? x 4 x i32>
+  %3 = "arm_sve.intr.sub"(%1, %arg0, %arg1) : (vector<[4]xi1>,
+                                               vector<[4]xi32>,
+                                               vector<[4]xi32>)
+                                               -> vector<[4]xi32>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
-  %4 = "arm_sve.intr.sub"(%2, %arg1, %arg0) : (!llvm.vec<? x 4 x i1>,
-                                               !llvm.vec<? x 4 x i32>,
-                                               !llvm.vec<? x 4 x i32>)
-                                               -> !llvm.vec<? x 4 x i32>
+  %4 = "arm_sve.intr.sub"(%2, %arg1, %arg0) : (vector<[4]xi1>,
+                                               vector<[4]xi32>,
+                                               vector<[4]xi32>)
+                                               -> vector<[4]xi32>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
-  %5 = "arm_sve.intr.add"(%1, %0, %3) : (!llvm.vec<? x 4 x i1>,
-                                         !llvm.vec<? x 4 x i32>,
-                                         !llvm.vec<? x 4 x i32>)
-                                         -> !llvm.vec<? x 4 x i32>
+  %5 = "arm_sve.intr.add"(%1, %0, %3) : (vector<[4]xi1>,
+                                         vector<[4]xi32>,
+                                         vector<[4]xi32>)
+                                         -> vector<[4]xi32>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
-  %6 = "arm_sve.intr.add"(%2, %5, %4) : (!llvm.vec<? x 4 x i1>,
-                                         !llvm.vec<? x 4 x i32>,
-                                         !llvm.vec<? x 4 x i32>)
-                                         -> !llvm.vec<? x 4 x i32>
-  llvm.return %6 : !llvm.vec<? x 4 x i32>
+  %6 = "arm_sve.intr.add"(%2, %5, %4) : (vector<[4]xi1>,
+                                         vector<[4]xi32>,
+                                         vector<[4]xi32>)
+                                         -> vector<[4]xi32>
+  llvm.return %6 : vector<[4]xi32>
 }
 
 // CHECK-LABEL: define void @memcopy
@@ -234,7 +234,7 @@ llvm.func @memcopy(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
   %12 = llvm.mlir.constant(0 : index) : i64
   %13 = llvm.mlir.constant(4 : index) : i64
   // CHECK: [[VL:%[0-9]+]] = call i64 @llvm.vscale.i64()
-  %14 = "arm_sve.vscale"() : () -> i64
+  %14 = "llvm.intr.vscale"() : () -> i64
   // CHECK: mul i64 [[VL]], 4
   %15 = llvm.mul %14, %13  : i64
   llvm.br ^bb1(%12 : i64)
@@ -249,9 +249,9 @@ llvm.func @memcopy(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
   // CHECK: etelementptr float, float*
   %19 = llvm.getelementptr %18[%16] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
   // CHECK: bitcast float* %{{[0-9]+}} to <vscale x 4 x float>*
-  %20 = llvm.bitcast %19 : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
+  %20 = llvm.bitcast %19 : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>>
   // CHECK: load <vscale x 4 x float>, <vscale x 4 x float>*
-  %21 = llvm.load %20 : !llvm.ptr<vec<? x 4 x f32>>
+  %21 = llvm.load %20 : !llvm.ptr<vector<[4]xf32>>
   // CHECK: extractvalue { float*, float*, i64, [1 x i64], [1 x i64] }
   %22 = llvm.extractvalue %11[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
                                                  array<1 x i64>,
@@ -259,9 +259,9 @@ llvm.func @memcopy(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
   // CHECK: getelementptr float, float* %32
   %23 = llvm.getelementptr %22[%16] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
   // CHECK: bitcast float* %33 to <vscale x 4 x float>*
-  %24 = llvm.bitcast %23 : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
+  %24 = llvm.bitcast %23 : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>>
   // CHECK: store <vscale x 4 x float> %{{[0-9]+}}, <vscale x 4 x float>* %{{[0-9]+}}
-  llvm.store %21, %24 : !llvm.ptr<vec<? x 4 x f32>>
+  llvm.store %21, %24 : !llvm.ptr<vector<[4]xf32>>
   %25 = llvm.add %16, %15  : i64
   llvm.br ^bb1(%25 : i64)
 ^bb3:
@@ -271,6 +271,6 @@ llvm.func @memcopy(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
 // CHECK-LABEL: define i64 @get_vector_scale()
 llvm.func @get_vector_scale() -> i64 {
   // CHECK: call i64 @llvm.vscale.i64()
-  %0 = "arm_sve.vscale"() : () -> i64
+  %0 = "llvm.intr.vscale"() : () -> i64
   llvm.return %0 : i64
 }