Add a type-constrained nested tuple type.
authorGeoffrey Martin-Noble <gcmn@google.com>
Fri, 24 May 2019 22:31:53 +0000 (15:31 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:01:03 +0000 (20:01 -0700)
    This is useful for dialects that use tuples but only support a subset of types.

--

PiperOrigin-RevId: 249910133

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/StandardTypes.h
mlir/lib/IR/StandardTypes.cpp
mlir/test/TestDialect/TestOps.td
mlir/test/TestDialect/tests/types.mlir [deleted file]
mlir/test/mlir-tblgen/types.mlir [new file with mode: 0644]

index cc6999f..780d92f 100644 (file)
@@ -414,8 +414,6 @@ def F64MemRef  : MemRefOf<[F64]>;
 // This represents a generic tuple without any constraints on element type.
 def AnyTuple : Type<IsTupleTypePred, "tuple">;
 
-// TODO(b/130358239) Support typed tuples of arbitrary nesting.
-
 // A container type that has other types embedded in it, but (unlike
 // ContainerType) can hold elements with a mix of types. Requires a call that
 // produces a list of all elements' types.
@@ -443,6 +441,20 @@ class TupleOf<list<Type> allowedTypes>
     : MixedContainerType<AnyTypeOf<allowedTypes>, IsTupleTypePred,
                          "$_self.cast<TupleType>().getTypes()", "tuple">;
 
+// A Tuple with arbitrary nesting, where all elements are a mix of the allowed
+// types.
+class NestedTupleOf<list<Type> allowedTypes> :
+    MixedContainerType<AnyTypeOf<allowedTypes>, IsTupleTypePred,
+                       // TODO(b/133502599) Make it possible to use a C++ helper
+                       [{
+                         [&](){
+                           SmallVector<Type, 10> fTypes;
+                           $_self.cast<TupleType>().getFlattenedTypes(fTypes);
+                           return fTypes;
+                         }()
+                       }],
+                       "nested tuple">;
+
 //===----------------------------------------------------------------------===//
 // Common type constraints
 //===----------------------------------------------------------------------===//
index 6733d75..de11f87 100644 (file)
@@ -481,6 +481,12 @@ public:
   /// Return the elements types for this tuple.
   ArrayRef<Type> getTypes() const;
 
+  /// Accumulate the types contained in this tuple and tuples nested within it.
+  /// Note that this only flattens nested tuples, not any other container type,
+  /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
+  /// (i32, tensor<i32>, f32, i64)
+  void getFlattenedTypes(SmallVectorImpl<Type> &types);
+
   /// Return the number of held types.
   unsigned size() const;
 
index b279d19..1958a43 100644 (file)
@@ -416,5 +416,18 @@ TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) {
 /// Return the elements types for this tuple.
 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
 
+/// Accumulate the types contained in this tuple and tuples nested within it.
+/// Note that this only flattens nested tuples, not any other container type,
+/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
+/// (i32, tensor<i32>, f32, i64)
+void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
+  for (Type type : getTypes()) {
+    if (auto nestedTuple = type.dyn_cast<TupleType>())
+      nestedTuple.getFlattenedTypes(types);
+    else
+      types.push_back(type);
+  }
+}
+
 /// Return the number of element types.
 unsigned TupleType::size() const { return getImpl()->size(); }
index 9bdfb63..984b9f2 100644 (file)
@@ -54,4 +54,8 @@ def TupleOp : TEST_Op<"tuple_32_bit"> {
   let results = (outs TupleOf<[I32, F32]>);
 }
 
+def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> {
+  let results = (outs NestedTupleOf<[I32, F32]>);
+}
+
 #endif // TEST_OPS
\ No newline at end of file
diff --git a/mlir/test/TestDialect/tests/types.mlir b/mlir/test/TestDialect/tests/types.mlir
deleted file mode 100644 (file)
index b402e46..0000000
+++ /dev/null
@@ -1,40 +0,0 @@
-// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
-
-// -----
-
-// CHECK-LABEL: @tuple_success
-func @tuple_success() {
-  %0 = "test.tuple_32_bit"() : () -> (tuple<i32>)
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @tuple_mixed_success
-func @tuple_mixed_success() {
-  %0 = "test.tuple_32_bit"() : () -> (tuple<i32, f32>)
-  return
-}
-
-// -----
-
-func @tuple_empty_success() {
-  %0 = "test.tuple_32_bit"() : () -> (tuple<>)
-  return
-}
-
-// -----
-
-func @tuple_wrong_type_scalar() {
-  // expected-error@+1 {{must be tuple with any combination of 32-bit integer or 32-bit float values}}
-  %0 = "test.tuple_32_bit"() : () -> (tuple<i64>)
-  return
-}
-
-// -----
-
-func @tuple_wrong_type_tensor() {
-  // expected-error@+1 {{must be tuple with any combination of 32-bit integer or 32-bit float values}}
-  %0 = "test.tuple_32_bit"() : () -> (tuple<tensor<i32>>)
-  return
-}
\ No newline at end of file
diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
new file mode 100644 (file)
index 0000000..1472487
--- /dev/null
@@ -0,0 +1,81 @@
+// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @tuple_success
+func @tuple_success() {
+  %0 = "test.tuple_32_bit"() : () -> (tuple<i32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @tuple_mixed_success
+func @tuple_mixed_success() {
+  %0 = "test.tuple_32_bit"() : () -> (tuple<i32, f32>)
+  return
+}
+
+// -----
+
+func @tuple_empty_success() {
+  %0 = "test.tuple_32_bit"() : () -> (tuple<>)
+  return
+}
+
+// -----
+
+func @tuple_wrong_type_scalar() {
+  // expected-error@+1 {{must be tuple with any combination of 32-bit integer or 32-bit float values}}
+  %0 = "test.tuple_32_bit"() : () -> (tuple<i64>)
+  return
+}
+
+// -----
+
+func @tuple_wrong_type_tensor() {
+  // expected-error@+1 {{must be tuple with any combination of 32-bit integer or 32-bit float values}}
+  %0 = "test.tuple_32_bit"() : () -> (tuple<tensor<i32>>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @nested_tuple_empty_success
+func @nested_tuple_empty_success() {
+  %0 = "test.nested_tuple_32_bit"() : () -> (tuple<>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @nested_tuple_one_level_success
+func @nested_tuple_one_level_success() {
+  %0 = "test.nested_tuple_32_bit"() : () -> (tuple<i32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @nested_tuple_multi_level_success
+func @nested_tuple_multi_level_success() {
+  %0 = "test.nested_tuple_32_bit"() : () -> (tuple<i32, tuple<i32, tuple<i32>>>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @nested_tuple_multi_level_mixed_success
+func @nested_tuple_multi_level_mixed_success() {
+  %0 = "test.nested_tuple_32_bit"() : () -> (tuple<i32, tuple<f32, tuple<i32>>>)
+  return
+}
+
+// -----
+
+func @nested_tuple_multi_level_wrong_type() {
+  // expected-error@+1 {{must be nested tuple with any combination of 32-bit integer or 32-bit float values}}
+  %0 = "test.nested_tuple_32_bit"() : () -> (tuple<i32, tuple<i32, tuple<i64>>>)
+  return
+}
+