Add TensorRankOf for ranked tensor types with specific ranks
authorLogan Chien <loganchien@google.com>
Fri, 30 Aug 2019 21:53:28 +0000 (14:53 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 30 Aug 2019 21:54:14 +0000 (14:54 -0700)
This commit adds `TensorRankOf<types, typeNames, ranks>` to specify ranked
tensor types with the specified types and ranks.  For example,
`TensorRankOf<[I32, F32], ["i32", "F32"], [0, 1]>` matches `tensor<i32>`,
`tensor<?xi32>`, `tensor<f32>`, or `tensor<?xf32>`.

PiperOrigin-RevId: 266461256

mlir/include/mlir/IR/OpBase.td
mlir/test/lib/TestDialect/TestOps.td
mlir/test/mlir-tblgen/types.mlir

index b4c921e..0ae17d0 100644 (file)
@@ -385,15 +385,6 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
 
 def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
 
-// Whether a type is a ranked tensor type.
-def HasRankPred : CPred<"$_self.cast<ShapedType>().hasRank()">;
-
-// Whether a type is a ranked tensor type with one of the specified ranks.
-class HasAnyRankOfPred<list<int> ranks> : And<[
-    HasRankPred,
-    Or<!foreach(rank, ranks,
-                CPred<"$_self.cast<ShapedType>().getRank() == " # rank>)>]>;
-
 def I1Tensor   : TensorOf<[I1]>;
 def I8Tensor   : TensorOf<[I8]>;
 def I16Tensor  : TensorOf<[I16]>;
@@ -405,6 +396,27 @@ def F16Tensor  : TensorOf<[F16]>;
 def F32Tensor  : TensorOf<[F32]>;
 def F64Tensor  : TensorOf<[F64]>;
 
+// Whether a type is a ranked tensor type.
+def HasRankPred : CPred<"$_self.cast<ShapedType>().hasRank()">;
+
+// Whether a type is a ranked tensor type with one of the specified ranks.
+class HasAnyRankOfPred<list<int> ranks> : And<[
+    HasRankPred,
+    Or<!foreach(rank, ranks,
+                CPred<"$_self.cast<ShapedType>().getRank() == " # rank>)>]>;
+
+// Ranked tensor type with one of the specified types and ranks.
+class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
+    Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
+         StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " #
+         TensorOf<allowedTypes>.description>;
+
+class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
+class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
+class 2DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [2]>;
+class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
+class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
+
 // Memref type.
 
 // Memrefs are blocks of data with fixed type and rank.
index f2d7aef..0010e1d 100644 (file)
@@ -50,10 +50,19 @@ def TakesStaticMemRefOp : TEST_Op<"takes_static_memref"> {
   let arguments = (ins AnyStaticShapeMemRef:$x);
 }
 
-def I32TensorRank0Or1Op : TEST_Op<"i32_tensor_rank_0_or_1"> {
+def NDTensorOfOp : TEST_Op<"nd_tensor_of"> {
   let arguments = (ins
-    Type<And<[I32Tensor.predicate, HasAnyRankOfPred<[0, 1]>]>,
-         "tensor<i32> or tensor<?xi32>">:$arg0
+    0DTensorOf<[F32]>:$arg0,
+    1DTensorOf<[F32]>:$arg1,
+    2DTensorOf<[I16]>:$arg2,
+    3DTensorOf<[I16]>:$arg3,
+    4DTensorOf<[I16]>:$arg4
+  );
+}
+
+def MultiTensorRankOf : TEST_Op<"multi_tensor_rank_of"> {
+  let arguments = (ins
+    TensorRankOf<[I8, I32, F32], [0, 1]>:$arg0
   );
 }
 
index f3ed197..6f4dfbb 100644 (file)
@@ -81,17 +81,52 @@ func @nested_tuple_multi_level_wrong_type() {
 
 // -----
 
-func @tensor_has_rank_0_or_1_success(%arg0: tensor<i32>, %arg1: tensor<5xi32>) {
-  "test.i32_tensor_rank_0_or_1"(%arg0) : (tensor<i32>) -> ()
-  "test.i32_tensor_rank_0_or_1"(%arg1) : (tensor<5xi32>) -> ()
+func @nd_tensor_of_success(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi16>) {
+  "test.nd_tensor_of"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<f32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<70x80x90x100xi16>) -> ()
   return
 }
 
 // -----
 
-func @tensor_has_rank_0_or_1_wrong_type(%arg0: tensor<2x2xi32>) {
-  // expected-error @+1 {{test.i32_tensor_rank_0_or_1' op operand #0 must be tensor<i32> or tensor<?xi32>}}
-  "test.i32_tensor_rank_0_or_1"(%arg0) : (tensor<2x2xi32>) -> ()
+func @nd_tensor_of_success_wrong_type_0d(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi32>) {
+  // expected-error @+1 {{'test.nd_tensor_of' op operand #0 must be 0D tensor of 32-bit float values}}
+  "test.nd_tensor_of"(%arg1, %arg1, %arg2, %arg3, %arg4) : (tensor<10xf32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<70x80x90x100xi32>) -> ()
+  return
+}
+
+// -----
+
+func @nd_tensor_of_success_wrong_type_4d(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi32>) {
+  // expected-error @+1 {{'test.nd_tensor_of' op operand #4 must be 4D tensor of 16-bit integer values}}
+  "test.nd_tensor_of"(%arg0, %arg1, %arg2, %arg3, %arg3) : (tensor<f32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<40x50x60xi16>) -> ()
+  return
+}
+
+// -----
+
+func @multi_tensor_rank_of_success(%arg0: tensor<i8>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<1xi8>, %arg4: tensor<1xi32>, %arg5: tensor<1xf32>) {
+  "test.multi_tensor_rank_of"(%arg0) : (tensor<i8>) -> ()
+  "test.multi_tensor_rank_of"(%arg1) : (tensor<i32>) -> ()
+  "test.multi_tensor_rank_of"(%arg2) : (tensor<f32>) -> ()
+  "test.multi_tensor_rank_of"(%arg3) : (tensor<1xi8>) -> ()
+  "test.multi_tensor_rank_of"(%arg4) : (tensor<1xi32>) -> ()
+  "test.multi_tensor_rank_of"(%arg5) : (tensor<1xf32>) -> ()
+  return
+}
+
+// -----
+
+func @multi_tensor_rank_of_wrong_unranked_type(%arg0: tensor<2x2xi8>) {
+  // expected-error @+1 {{'test.multi_tensor_rank_of' op operand #0 must be 0D/1D tensor of 8-bit integer or 32-bit integer or 32-bit float values}}
+  "test.multi_tensor_rank_of"(%arg0) : (tensor<2x2xi8>) -> ()
+  return
+}
+
+// -----
+
+func @multi_tensor_rank_of_wrong_element_type(%arg0: tensor<2xi16>) {
+  // expected-error @+1 {{'test.multi_tensor_rank_of' op operand #0 must be 0D/1D tensor of 8-bit integer or 32-bit integer or 32-bit float values}}
+  "test.multi_tensor_rank_of"(%arg0) : (tensor<2xi16>) -> ()
   return
 }