[mlir][transform] SplitHandleOp: add additional distribution options
authorMatthias Springer <me@m-sp.org>
Tue, 9 May 2023 09:37:13 +0000 (11:37 +0200)
committerMatthias Springer <me@m-sp.org>
Tue, 9 May 2023 09:38:18 +0000 (11:38 +0200)
Add options to handle cases where there are not enough or too many payload ops mapped to the given handle.

Differential Revision: https://reviews.llvm.org/D149955

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir

index 8154835..0408b34 100644 (file)
@@ -521,8 +521,8 @@ def SplitHandleOp : TransformDialectOp<"split_handle",
   let description = [{
     Splits `handle` into one or multiple handles, as specified by the number
     of results of this operation. `handle` should be mapped to as many payload
-    ops as there are results. Otherwise, this transform will fail silently.
-    Each result handle is mapped to exactly one payload op. The order
+    ops as there are results. Otherwise, this transform will fail silently by
+    default. Each result handle is mapped to exactly one payload op. The order
     of the payload ops is preserved, i.e., the i-th payload op is mapped to the
     i-th result handle.
 
@@ -530,12 +530,23 @@ def SplitHandleOp : TransformDialectOp<"split_handle",
     operations are tracked by the source `handle` and to extract them into
     individual handles that can be further manipulated in isolation.
 
-    If `handle` is empty, this transform will succeed and all result handles
-    are empty.
+    If there are more payload ops than results, the remaining ops are mapped to
+    the result with index `overflow_result`. If no `overflow_result` is
+    specified, the transform fails silently.
+
+    If there are fewer payload ops than results, the transform fails silently
+    if `fail_on_payload_too_small` is set to "true". Otherwise, it succeeds and
+    the remaining result handles are not mapped to any op. It also succeeds if
+    `handle` is empty and `pass_through_empty_handle` is set to "true",
+    regardless of `fail_on_payload_too_small`.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$handle);
+  let arguments = (ins TransformHandleTypeInterface:$handle,
+                       DefaultValuedAttr<BoolAttr, "true">:$pass_through_empty_handle,
+                       DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_too_small,
+                       OptionalAttr<I64Attr>:$overflow_result);
   let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+  let hasVerifier = 1;
 
   let builders = [
     OpBuilder<(ins "Value":$handle, "int64_t":$numResultHandles)>
index 5c39ccc..62ef94d 100644 (file)
@@ -1502,24 +1502,40 @@ DiagnosedSilenceableFailure
 transform::SplitHandleOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
   int64_t numPayloadOps = state.getPayloadOps(getHandle()).size();
-
-  // Empty handle corner case: all result handles are empty.
-  if (numPayloadOps == 0) {
-    for (OpResult result : getResults())
-      results.set(result, {});
-    return DiagnosedSilenceableFailure::success();
-  }
-
-  // If the input handle was not empty and the number of payload ops does not
-  // match, this is a legit silenceable error.
-  if (numPayloadOps != getNumResults())
+  auto produceNumOpsError = [&]() {
     return emitSilenceableError()
-           << getHandle() << " expected to contain " << getNumResults()
+           << getHandle() << " expected to contain " << this->getNumResults()
            << " payload ops but it contains " << numPayloadOps
            << " payload ops";
+  };
 
-  for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle())))
-    results.set(getResults()[en.index()].cast<OpResult>(), en.value());
+  // Fail if there are more payload ops than results and no overflow result was
+  // specified.
+  if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
+    return produceNumOpsError();
+
+  // Fail if there are more results than payload ops. Unless:
+  // - "fail_on_payload_too_small" is set to "false", or
+  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
+  if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
+      !(numPayloadOps == 0 && getPassThroughEmptyHandle()))
+    return produceNumOpsError();
+
+  // Distribute payload ops.
+  SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
+  if (getOverflowResult())
+    resultHandles[*getOverflowResult()].reserve(numPayloadOps -
+                                                getNumResults());
+  for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
+    int64_t resultNum = en.index();
+    if (resultNum >= getNumResults())
+      resultNum = *getOverflowResult();
+    resultHandles[resultNum].push_back(en.value());
+  }
+
+  // Set transform op results.
+  for (auto &&it : llvm::enumerate(resultHandles))
+    results.set(getResult(it.index()).cast<OpResult>(), it.value());
 
   return DiagnosedSilenceableFailure::success();
 }
@@ -1532,6 +1548,13 @@ void transform::SplitHandleOp::getEffects(
   // manipulation.
 }
 
+LogicalResult transform::SplitHandleOp::verify() {
+  if (getOverflowResult().has_value() &&
+      !(*getOverflowResult() >= 0 && *getOverflowResult() < getNumResults()))
+    return emitOpError("overflow_result is not a valid result index");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // PDLMatchOp
 //===----------------------------------------------------------------------===//
index 9ed9957..8ceb72d 100644 (file)
@@ -858,6 +858,47 @@ transform.sequence failures(suppress) {
 
 // -----
 
+func.func @split_handle(%a: index, %b: index, %c: index) {
+  %0 = arith.muli %a, %b : index
+  %1 = arith.muli %a, %c : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%fun: !pdl.operation):
+  %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
+  // No error, last result handle is empty.
+  %h:3 = split_handle %muli_2 {fail_on_payload_too_small = false} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#0
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#1
+  // expected-remark @below {{0}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#2
+}
+
+// -----
+
+func.func @split_handle(%a: index, %b: index, %c: index) {
+  %0 = arith.muli %a, %b : index
+  %1 = arith.muli %a, %c : index
+  %2 = arith.muli %a, %c : index
+  %3 = arith.muli %a, %c : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%fun: !pdl.operation):
+  %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
+  %h:2 = split_handle %muli_2 {overflow_result = 0} : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  // expected-remark @below {{3}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#0
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#1
+}
+
+// -----
+
 "test.some_op"() : () -> ()
 "other_dialect.other_op"() : () -> ()