This commit adds `TensorRankOf<types, typeNames, ranks>` to specify ranked
tensor types with the specified types and ranks. For example,
`TensorRankOf<[I32, F32], ["i32", "F32"], [0, 1]>` matches `tensor<i32>`,
`tensor<?xi32>`, `tensor<f32>`, or `tensor<?xf32>`.
PiperOrigin-RevId:
266461256
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
-// Whether a type is a ranked tensor type.
-def HasRankPred : CPred<"$_self.cast<ShapedType>().hasRank()">;
-
-// Whether a type is a ranked tensor type with one of the specified ranks.
-class HasAnyRankOfPred<list<int> ranks> : And<[
- HasRankPred,
- Or<!foreach(rank, ranks,
- CPred<"$_self.cast<ShapedType>().getRank() == " # rank>)>]>;
-
def I1Tensor : TensorOf<[I1]>;
def I8Tensor : TensorOf<[I8]>;
def I16Tensor : TensorOf<[I16]>;
def F32Tensor : TensorOf<[F32]>;
def F64Tensor : TensorOf<[F64]>;
+// Whether a type is a ranked tensor type.
+def HasRankPred : CPred<"$_self.cast<ShapedType>().hasRank()">;
+
+// Whether a type is a ranked tensor type with one of the specified ranks.
+class HasAnyRankOfPred<list<int> ranks> : And<[
+ HasRankPred,
+ Or<!foreach(rank, ranks,
+ CPred<"$_self.cast<ShapedType>().getRank() == " # rank>)>]>;
+
+// Ranked tensor type with one of the specified types and ranks.
+class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
+ Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
+ StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " #
+ TensorOf<allowedTypes>.description>;
+
+class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
+class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
+class 2DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [2]>;
+class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
+class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
+
// Memref type.
// Memrefs are blocks of data with fixed type and rank.
let arguments = (ins AnyStaticShapeMemRef:$x);
}
-def I32TensorRank0Or1Op : TEST_Op<"i32_tensor_rank_0_or_1"> {
+def NDTensorOfOp : TEST_Op<"nd_tensor_of"> {
let arguments = (ins
- Type<And<[I32Tensor.predicate, HasAnyRankOfPred<[0, 1]>]>,
- "tensor<i32> or tensor<?xi32>">:$arg0
+ 0DTensorOf<[F32]>:$arg0,
+ 1DTensorOf<[F32]>:$arg1,
+ 2DTensorOf<[I16]>:$arg2,
+ 3DTensorOf<[I16]>:$arg3,
+ 4DTensorOf<[I16]>:$arg4
+ );
+}
+
+def MultiTensorRankOf : TEST_Op<"multi_tensor_rank_of"> {
+ let arguments = (ins
+ TensorRankOf<[I8, I32, F32], [0, 1]>:$arg0
);
}
// -----
-func @tensor_has_rank_0_or_1_success(%arg0: tensor<i32>, %arg1: tensor<5xi32>) {
- "test.i32_tensor_rank_0_or_1"(%arg0) : (tensor<i32>) -> ()
- "test.i32_tensor_rank_0_or_1"(%arg1) : (tensor<5xi32>) -> ()
+func @nd_tensor_of_success(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi16>) {
+ "test.nd_tensor_of"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<f32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<70x80x90x100xi16>) -> ()
return
}
// -----
-func @tensor_has_rank_0_or_1_wrong_type(%arg0: tensor<2x2xi32>) {
- // expected-error @+1 {{test.i32_tensor_rank_0_or_1' op operand #0 must be tensor<i32> or tensor<?xi32>}}
- "test.i32_tensor_rank_0_or_1"(%arg0) : (tensor<2x2xi32>) -> ()
+func @nd_tensor_of_success_wrong_type_0d(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi32>) {
+ // expected-error @+1 {{'test.nd_tensor_of' op operand #0 must be 0D tensor of 32-bit float values}}
+ "test.nd_tensor_of"(%arg1, %arg1, %arg2, %arg3, %arg4) : (tensor<10xf32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<70x80x90x100xi32>) -> ()
+ return
+}
+
+// -----
+
+func @nd_tensor_of_success_wrong_type_4d(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi32>) {
+ // expected-error @+1 {{'test.nd_tensor_of' op operand #4 must be 4D tensor of 16-bit integer values}}
+ "test.nd_tensor_of"(%arg0, %arg1, %arg2, %arg3, %arg3) : (tensor<f32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<40x50x60xi16>) -> ()
+ return
+}
+
+// -----
+
+func @multi_tensor_rank_of_success(%arg0: tensor<i8>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<1xi8>, %arg4: tensor<1xi32>, %arg5: tensor<1xf32>) {
+ "test.multi_tensor_rank_of"(%arg0) : (tensor<i8>) -> ()
+ "test.multi_tensor_rank_of"(%arg1) : (tensor<i32>) -> ()
+ "test.multi_tensor_rank_of"(%arg2) : (tensor<f32>) -> ()
+ "test.multi_tensor_rank_of"(%arg3) : (tensor<1xi8>) -> ()
+ "test.multi_tensor_rank_of"(%arg4) : (tensor<1xi32>) -> ()
+ "test.multi_tensor_rank_of"(%arg5) : (tensor<1xf32>) -> ()
+ return
+}
+
+// -----
+
+func @multi_tensor_rank_of_wrong_unranked_type(%arg0: tensor<2x2xi8>) {
+ // expected-error @+1 {{'test.multi_tensor_rank_of' op operand #0 must be 0D/1D tensor of 8-bit integer or 32-bit integer or 32-bit float values}}
+ "test.multi_tensor_rank_of"(%arg0) : (tensor<2x2xi8>) -> ()
+ return
+}
+
+// -----
+
+func @multi_tensor_rank_of_wrong_element_type(%arg0: tensor<2xi16>) {
+ // expected-error @+1 {{'test.multi_tensor_rank_of' op operand #0 must be 0D/1D tensor of 8-bit integer or 32-bit integer or 32-bit float values}}
+ "test.multi_tensor_rank_of"(%arg0) : (tensor<2xi16>) -> ()
return
}