[mlir][func] Add support for nested tuples to TestDecomposeCallGraphTypes.
authorIngo Müller <ingomueller@google.com>
Wed, 8 Feb 2023 13:10:00 +0000 (13:10 +0000)
committerIngo Müller <ingomueller@google.com>
Thu, 9 Feb 2023 05:22:01 +0000 (05:22 +0000)
Nested tuples were only supported in some narrow edge cases (and
potentially only because the test ops like `test.make_tuple` aren't
properly verified). This patch adds a couple of test cases with tested
tuple types and makes them work in the test pass by extending the
argument materialization and decomposition functions.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D143579

mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
mlir/test/Transforms/decompose-call-graph-types.mlir
mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp

index 6f27cbb..29bab1d 100644 (file)
@@ -53,8 +53,8 @@ public:
 
   /// This method registers a callback function that will be called to decompose
   /// a value of a certain type into 0, 1, or multiple values.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
+  template <typename FnT, typename T = typename llvm::function_traits<
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addDecomposeValueConversion(FnT &&callback) {
     decomposeValueConversions.emplace_back(
         wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
index 5ecbad1..604e948 100644 (file)
@@ -37,6 +37,29 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
 
 // -----
 
+// Test case: Type that needs to be recursively decomposed at different recursion depths.
+
+// CHECK-LABEL:   func @mixed_recursive_decomposition(
+// CHECK-SAME:                 %[[ARG0:.*]]: i1,
+// CHECK-SAME:                 %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK:           %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
+// CHECK:           %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple<i1>
+// CHECK:           %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
+// CHECK:           %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>>
+// CHECK:           %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
+// CHECK:           %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
+// CHECK:           %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 1 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
+// CHECK:           %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<i1>) -> i1
+// CHECK:           %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 2 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK:           %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK:           %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple<i2>) -> i2
+// CHECK:           return %[[V7]], %[[V10]] : i1, i2
+func.func @mixed_recursive_decomposition(%arg0: tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>> {
+  return %arg0 : tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
+}
+
+// -----
+
 // Test case: Check decomposition of calls.
 
 // CHECK-LABEL:   func private @callee(i1, i32) -> (i1, i32)
@@ -89,6 +112,26 @@ func.func @unconverted_op_result() -> tuple<i1, i32> {
 
 // -----
 
+// Test case: Ensure decompositions are inserted properly around results of
+// unconverted ops in the case of different nesting levels.
+
+// CHECK-LABEL:   func @nested_unconverted_op_result(
+// CHECK-SAME:                 %[[ARG0:.*]]: i1,
+// CHECK-SAME:                 %[[ARG1:.*]]: i32) -> (i1, i32) {
+// CHECK:           %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple<i32>
+// CHECK:           %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple<i32>) -> tuple<i1, tuple<i32>>
+// CHECK:           %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>>
+// CHECK:           %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple<i1, tuple<i32>>) -> i1
+// CHECK:           %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple<i1, tuple<i32>>) -> tuple<i32>
+// CHECK:           %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<i32>) -> i32
+// CHECK:           return %[[V3]], %[[V5]] : i1, i32
+func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>> {
+  %0 = "test.op"(%arg) : (tuple<i1, tuple<i32>>) -> (tuple<i1, tuple<i32>>)
+  return %0 : tuple<i1, tuple<i32>>
+}
+
+// -----
+
 // Test case: Check mixed decomposed and non-decomposed args.
 // This makes sure to test the cases if 1:0, 1:1, and 1:N decompositions.
 
index 9492d23..41e1666 100644 (file)
 using namespace mlir;
 
 namespace {
+/// Creates a sequence of `test.get_tuple_element` ops for all elements of a
+/// given tuple value. If some tuple elements are, in turn, tuples, the elements
+/// of those are extracted recursively such that the returned values have the
+/// same types as `resultTypes.getFlattenedTypes()`.
+static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
+                                         TupleType resultType, Value value,
+                                         SmallVectorImpl<Value> &values) {
+  for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
+    Type elementType = resultType.getType(i);
+    Value element = builder.create<test::GetTupleElementOp>(
+        loc, elementType, value, builder.getI32IntegerAttr(i));
+    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+      // Recurse if the current element is also a tuple.
+      if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element,
+                                     values)))
+        return failure();
+    } else {
+      values.push_back(element);
+    }
+  }
+  return success();
+}
+
+/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
+/// type `resultType`. If that type is nested, each nested tuple is built
+/// recursively with another `test.make_tuple` op.
+static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
+                                             TupleType resultType,
+                                             ValueRange inputs, Location loc) {
+  // Build one value for each element at this nesting level.
+  SmallVector<Value> elements;
+  elements.reserve(resultType.getTypes().size());
+  ValueRange::iterator inputIt = inputs.begin();
+  for (Type elementType : resultType.getTypes()) {
+    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+      // Determine how many input values are needed for the nested elements of
+      // the nested TupleType and advance inputIt by that number.
+      // TODO: We only need the *number* of nested types, not the types itself.
+      //       Maybe it's worth adding a more efficient overload?
+      SmallVector<Type> nestedFlattenedTypes;
+      nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
+      size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
+      ValueRange nestedFlattenedelements(inputIt,
+                                         inputIt + numNestedFlattenedTypes);
+      inputIt += numNestedFlattenedTypes;
+
+      // Recurse on the values for the nested TupleType.
+      std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
+                                                  nestedFlattenedelements, loc);
+      if (!res.has_value())
+        return {};
+
+      // The tuple constructed by the conversion is the element value.
+      elements.push_back(res.value());
+    } else {
+      // Base case: take one input as is.
+      elements.push_back(*inputIt++);
+    }
+  }
+
+  // Assemble the tuple from the elements.
+  return builder.create<test::MakeTupleOp>(loc, resultType, elements);
+}
+
 /// A pass for testing call graph type decomposition.
 ///
 /// This instantiates the patterns with a TypeConverter and ValueDecomposer
@@ -39,7 +103,6 @@ struct TestDecomposeCallGraphTypes
     auto *context = &getContext();
     TypeConverter typeConverter;
     ConversionTarget target(*context);
-    ValueDecomposer decomposer;
     RewritePatternSet patterns(context);
 
     target.addLegalDialect<test::TestDialect>();
@@ -59,27 +122,10 @@ struct TestDecomposeCallGraphTypes
           tupleType.getFlattenedTypes(types);
           return success();
         });
+    typeConverter.addArgumentMaterialization(buildMakeTupleOp);
 
-    decomposer.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
-                                              TupleType resultType, Value value,
-                                              SmallVectorImpl<Value> &values) {
-      for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
-        Value res = builder.create<test::GetTupleElementOp>(
-            loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
-        values.push_back(res);
-      }
-      return success();
-    });
-
-    typeConverter.addArgumentMaterialization(
-        [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
-           Location loc) -> std::optional<Value> {
-          if (inputs.size() == 1)
-            return std::nullopt;
-          TupleType tuple = builder.getTupleType(inputs.getTypes());
-          Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
-          return value;
-        });
+    ValueDecomposer decomposer;
+    decomposer.addDecomposeValueConversion(buildDecomposeTuple);
 
     populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
                                             patterns);