// Vector types.
-class Vector<Type t> : ShapedContainerType<t, IsVectorTypePred, "vector">;
+class VectorOf<list<Type> allowedTypes, string elementDescription = ""> :
+ ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
+ IsVectorTypePred, "vector">;
-def AnyVector : Vector<AnyType>;
+def AnyVector : VectorOf<[AnyType]>;
// Tensor types.
-class Tensor<Type t> : ShapedContainerType<t, IsTensorTypePred, "tensor">;
-
-def AnyTensor : Tensor<AnyType>;
-
// Any tensor type whose element type is from the given `allowedTypes` list
-class AnyTensorOf<list<Type> allowedTypes, string elementDescription = ""> :
- Tensor<AnyTypeOf<allowedTypes, elementDescription>>;
+class TensorOf<list<Type> allowedTypes, string elementDescription = ""> :
+ ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
+ IsTensorTypePred, "tensor">;
+
+def AnyTensor : TensorOf<[AnyType]>;
// TODO(b/130807343) Fix description to contain element information.
class StaticShapeTensor<Type t>
- : Type<And<[ Tensor<t>.predicate, HasStaticShapePred ]>,
+ : Type<And<[ TensorOf<[t]>.predicate, HasStaticShapePred ]>,
"statically shaped tensor">;
def AnyStaticShapeTensor : StaticShapeTensor<AnyType>;
-def I1Tensor : Tensor<I1>;
-def I8Tensor : Tensor<I8>;
-def I16Tensor : Tensor<I16>;
-def I32Tensor : Tensor<I32>;
-def I64Tensor : Tensor<I64>;
+def I1Tensor : TensorOf<[I1]>;
+def I8Tensor : TensorOf<[I8]>;
+def I16Tensor : TensorOf<[I16]>;
+def I32Tensor : TensorOf<[I32]>;
+def I64Tensor : TensorOf<[I64]>;
+
+def BF16Tensor : TensorOf<[BF16]>;
+def F16Tensor : TensorOf<[F16]>;
+def F32Tensor : TensorOf<[F32]>;
+def F64Tensor : TensorOf<[F64]>;
+
+// Memref type.
+
+// TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType.
+// Memrefs are blocks of data with fixed type and rank.
+class MemRefOf<list<Type> allowedTypes, string elementDescription = ""> :
+ ContainerType<AnyTypeOf<allowedTypes, elementDescription>, IsMemRefTypePred,
+ "$_self.cast<MemRefType>().getElementType()", "memref">;
+
+def AnyMemRef : MemRefOf<[AnyType]>;
+
+// Memref declarations handle any memref, independent of rank, size, (static or
+// dynamic), layout, or memory space.
+def I1MemRef : MemRefOf<[I1]>;
+def I8MemRef : MemRefOf<[I8]>;
+def I16MemRef : MemRefOf<[I16]>;
+def I32MemRef : MemRefOf<[I32]>;
+def I64MemRef : MemRefOf<[I64]>;
-def BF16Tensor : Tensor<BF16>;
-def F16Tensor : Tensor<F16>;
-def F32Tensor : Tensor<F32>;
-def F64Tensor : Tensor<F64>;
+def BF16MemRef : MemRefOf<[BF16]>;
+def F16MemRef : MemRefOf<[F16]>;
+def F32MemRef : MemRefOf<[F32]>;
+def F64MemRef : MemRefOf<[F64]>;
-// This represents a generic tuple without any constraints on elemental type,
-// ranks, or size. As Tuples can contain tensors, vectors, or scalar values
-// there is not only a single elemental type.
-def Tuple : Type<IsTupleTypePred, "tuple">;
+// This represents a generic tuple without any constraints on element type.
+def AnyTuple : Type<IsTupleTypePred, "tuple">;
+// TODO(b/132952417) Make this accept a list of types like the classes above.
// A Tuple that only holds elements of a certain type. This cannot inherit from
// ContainerType because tuples do not always have a single element type that
// could be retrieved with elementTypeCall.
-class TypedTuple<Type t> :
+class TupleOf<Type t> :
Type<And<[
- Tuple.predicate,
+ IsTupleTypePred,
Concat<
[{
llvm::all_of(
"; })">
]>, "tuple">;
-// Memref type.
-
-// Memrefs are blocks of data with fixed type and rank.
-class MemRef<Type t>
- : ContainerType<t, IsMemRefTypePred,
- "$_self.cast<MemRefType>().getElementType()", "memref">;
-
-// Memref declarations handle any memref, independent of rank, size, (static or
-// dynamic), layout, or memory space.
-def I1MemRef : MemRef<I1>;
-def I8MemRef : MemRef<I8>;
-def I16MemRef : MemRef<I16>;
-def I32MemRef : MemRef<I32>;
-def I64MemRef : MemRef<I64>;
-
-def BF16MemRef : MemRef<BF16>;
-def F16MemRef : MemRef<F16>;
-def F32MemRef : MemRef<F32>;
-def F64MemRef : MemRef<F64>;
-
//===----------------------------------------------------------------------===//
// Common type constraints
//===----------------------------------------------------------------------===//
// Type constraint for integer-like types: integers, indices, vectors of
// integers, tensors of integers.
def IntegerLike : TypeConstraint<Or<[Integer.predicate, Index.predicate,
- Vector<Integer>.predicate, Tensor<Integer>.predicate]>,
+ VectorOf<[Integer]>.predicate, TensorOf<[Integer]>.predicate]>,
"integer-like">;
// Type constraint for float-like types: floats, vectors or tensors thereof.
def FloatLike : TypeConstraint<Or<[Float.predicate,
- Vector<Float>.predicate, Tensor<Float>.predicate]>,
+ VectorOf<[Float]>.predicate, TensorOf<[Float]>.predicate]>,
"floating-point-like">;
}];
let arguments = (ins Variadic<Index>:$value);
- let results = (outs MemRef<AnyType>);
+ let results = (outs AnyMemRef);
let builders = [OpBuilder<
"Builder *builder, OperationState *result, MemRefType memrefType", [{
dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
}];
- let arguments = (ins MemRef<AnyType>:$memref);
+ let arguments = (ins AnyMemRef:$memref);
let hasCanonicalizer = 1;
}
%1 = dim %0, 2 : tensor<?x?x?xf32>
}];
- let arguments = (ins AnyTypeOf<[MemRef<AnyType>, AnyTensor],
+ let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor],
"any tensor or memref type">:$memrefOrTensor,
APIntAttr:$index);
let results = (outs Index);
%3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
}];
- let arguments = (ins MemRef<AnyType>);
- let results = (outs MemRef<AnyType>);
+ let arguments = (ins AnyMemRef);
+ let results = (outs AnyMemRef);
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for