From 9f1f91e7703bdf2ed88fd76da24b23681ac2e7b1 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Fri, 24 May 2019 15:31:53 -0700 Subject: [PATCH] Add a type-constrained nested tuple type. This is useful for dialects that use tuples but only support a subset of types. -- PiperOrigin-RevId: 249910133 --- mlir/include/mlir/IR/OpBase.td | 16 ++++++- mlir/include/mlir/IR/StandardTypes.h | 6 +++ mlir/lib/IR/StandardTypes.cpp | 13 ++++++ mlir/test/TestDialect/TestOps.td | 4 ++ mlir/test/TestDialect/tests/types.mlir | 40 ----------------- mlir/test/mlir-tblgen/types.mlir | 81 ++++++++++++++++++++++++++++++++++ 6 files changed, 118 insertions(+), 42 deletions(-) delete mode 100644 mlir/test/TestDialect/tests/types.mlir create mode 100644 mlir/test/mlir-tblgen/types.mlir diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index cc6999f..780d92f 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -414,8 +414,6 @@ def F64MemRef : MemRefOf<[F64]>; // This represents a generic tuple without any constraints on element type. def AnyTuple : Type; -// TODO(b/130358239) Support typed tuples of arbitrary nesting. - // A container type that has other types embedded in it, but (unlike // ContainerType) can hold elements with a mix of types. Requires a call that // produces a list of all elements' types. @@ -443,6 +441,20 @@ class TupleOf allowedTypes> : MixedContainerType, IsTupleTypePred, "$_self.cast().getTypes()", "tuple">; +// A Tuple with arbitrary nesting, where all elements are a mix of the allowed +// types. +class NestedTupleOf allowedTypes> : + MixedContainerType, IsTupleTypePred, + // TODO(b/133502599) Make it possible to use a C++ helper + [{ + [&](){ + SmallVector fTypes; + $_self.cast().getFlattenedTypes(fTypes); + return fTypes; + }() + }], + "nested tuple">; + //===----------------------------------------------------------------------===// // Common type constraints //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h index 6733d75..de11f87 100644 --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -481,6 +481,12 @@ public: /// Return the elements types for this tuple. ArrayRef getTypes() const; + /// Accumulate the types contained in this tuple and tuples nested within it. + /// Note that this only flattens nested tuples, not any other container type, + /// e.g. a tuple, tuple>> is flattened to + /// (i32, tensor, f32, i64) + void getFlattenedTypes(SmallVectorImpl &types); + /// Return the number of held types. unsigned size() const; diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index b279d19..1958a43 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -416,5 +416,18 @@ TupleType TupleType::get(ArrayRef elementTypes, MLIRContext *context) { /// Return the elements types for this tuple. ArrayRef TupleType::getTypes() const { return getImpl()->getTypes(); } +/// Accumulate the types contained in this tuple and tuples nested within it. +/// Note that this only flattens nested tuples, not any other container type, +/// e.g. a tuple, tuple>> is flattened to +/// (i32, tensor, f32, i64) +void TupleType::getFlattenedTypes(SmallVectorImpl &types) { + for (Type type : getTypes()) { + if (auto nestedTuple = type.dyn_cast()) + nestedTuple.getFlattenedTypes(types); + else + types.push_back(type); + } +} + /// Return the number of element types. unsigned TupleType::size() const { return getImpl()->size(); } diff --git a/mlir/test/TestDialect/TestOps.td b/mlir/test/TestDialect/TestOps.td index 9bdfb63..984b9f2 100644 --- a/mlir/test/TestDialect/TestOps.td +++ b/mlir/test/TestDialect/TestOps.td @@ -54,4 +54,8 @@ def TupleOp : TEST_Op<"tuple_32_bit"> { let results = (outs TupleOf<[I32, F32]>); } +def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> { + let results = (outs NestedTupleOf<[I32, F32]>); +} + #endif // TEST_OPS \ No newline at end of file diff --git a/mlir/test/TestDialect/tests/types.mlir b/mlir/test/TestDialect/tests/types.mlir deleted file mode 100644 index b402e46..0000000 --- a/mlir/test/TestDialect/tests/types.mlir +++ /dev/null @@ -1,40 +0,0 @@ -// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s - -// ----- - -// CHECK-LABEL: @tuple_success -func @tuple_success() { - %0 = "test.tuple_32_bit"() : () -> (tuple) - return -} - -// ----- - -// CHECK-LABEL: @tuple_mixed_success -func @tuple_mixed_success() { - %0 = "test.tuple_32_bit"() : () -> (tuple) - return -} - -// ----- - -func @tuple_empty_success() { - %0 = "test.tuple_32_bit"() : () -> (tuple<>) - return -} - -// ----- - -func @tuple_wrong_type_scalar() { - // expected-error@+1 {{must be tuple with any combination of 32-bit integer or 32-bit float values}} - %0 = "test.tuple_32_bit"() : () -> (tuple) - return -} - -// ----- - -func @tuple_wrong_type_tensor() { - // expected-error@+1 {{must be tuple with any combination of 32-bit integer or 32-bit float values}} - %0 = "test.tuple_32_bit"() : () -> (tuple>) - return -} \ No newline at end of file diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir new file mode 100644 index 0000000..1472487 --- /dev/null +++ b/mlir/test/mlir-tblgen/types.mlir @@ -0,0 +1,81 @@ +// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s + +// ----- + +// CHECK-LABEL: @tuple_success +func @tuple_success() { + %0 = "test.tuple_32_bit"() : () -> (tuple) + return +} + +// ----- + +// CHECK-LABEL: @tuple_mixed_success +func @tuple_mixed_success() { + %0 = "test.tuple_32_bit"() : () -> (tuple) + return +} + +// ----- + +func @tuple_empty_success() { + %0 = "test.tuple_32_bit"() : () -> (tuple<>) + return +} + +// ----- + +func @tuple_wrong_type_scalar() { + // expected-error@+1 {{must be tuple with any combination of 32-bit integer or 32-bit float values}} + %0 = "test.tuple_32_bit"() : () -> (tuple) + return +} + +// ----- + +func @tuple_wrong_type_tensor() { + // expected-error@+1 {{must be tuple with any combination of 32-bit integer or 32-bit float values}} + %0 = "test.tuple_32_bit"() : () -> (tuple>) + return +} + +// ----- + +// CHECK-LABEL: @nested_tuple_empty_success +func @nested_tuple_empty_success() { + %0 = "test.nested_tuple_32_bit"() : () -> (tuple<>) + return +} + +// ----- + +// CHECK-LABEL: @nested_tuple_one_level_success +func @nested_tuple_one_level_success() { + %0 = "test.nested_tuple_32_bit"() : () -> (tuple) + return +} + +// ----- + +// CHECK-LABEL: @nested_tuple_multi_level_success +func @nested_tuple_multi_level_success() { + %0 = "test.nested_tuple_32_bit"() : () -> (tuple>>) + return +} + +// ----- + +// CHECK-LABEL: @nested_tuple_multi_level_mixed_success +func @nested_tuple_multi_level_mixed_success() { + %0 = "test.nested_tuple_32_bit"() : () -> (tuple>>) + return +} + +// ----- + +func @nested_tuple_multi_level_wrong_type() { + // expected-error@+1 {{must be nested tuple with any combination of 32-bit integer or 32-bit float values}} + %0 = "test.nested_tuple_32_bit"() : () -> (tuple>>) + return +} + -- 2.7.4