[mlir][Transform] Add a transform.split_handles operation and fix general silenceable...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 7 Oct 2022 08:43:38 +0000 (01:43 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 7 Oct 2022 16:01:34 +0000 (09:01 -0700)
The transform.split_handles op is useful for ensuring a statically known number of operations are
tracked by the source `handle` and to extract them into individual handles
that can be further manipulated in isolation.

In the process of making the op robust wrt to silenceable errors and the suppress mode, issues were
uncovered and fixed.

The main issue was that silenceable errors were short-circuited too early and the payloads were not
set. This resulted in suppressed silenceable errors not propagating correctly.
Fixing the issue triggered a few test failures: silenceable error returns now must properly set the results state.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D135426

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

index 4b97992..817c1c6 100644 (file)
@@ -210,6 +210,34 @@ def MergeHandlesOp : TransformDialectOp<"merge_handles",
   let hasFolder = 1;
 }
 
+def SplitHandlesOp : TransformDialectOp<"split_handles",
+    [FunctionalStyleTransformOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Splits handles from a union of payload ops to a list";
+  let description = [{
+    Creates `num_result_handles` transform IR handles extracted from the 
+    `handle` operand. The resulting Payload IR operation handles are listed
+    in the same order as the operations appear in the source `handle`.
+    This is useful for ensuring a statically known number of operations are
+    tracked by the source `handle` and to extract them into individual handles
+    that can be further manipulated in isolation.
+
+    This operation succeeds and returns `num_result_handles` if the statically
+    specified `num_result_handles` corresponds to the dynamic number of 
+    operations contained in the source `handle`. Otherwise it silently fails.
+  }];
+
+  let arguments = (ins PDL_Operation:$handle,
+                       I64Attr:$num_result_handles);
+  let results = (outs Variadic<PDL_Operation>:$results);
+  let assemblyFormat = [{
+    $handle `in` `[` $num_result_handles `]` 
+    custom<StaticNumPDLResults>(type($results), ref($num_result_handles))
+    attr-dict
+  }];
+}
+
 def PDLMatchOp : TransformDialectOp<"pdl_match",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
index 4245d4c..99f93ed 100644 (file)
@@ -590,9 +590,11 @@ transform::MatchOp::apply(transform::TransformResults &results,
                 getOps()->getAsValueRange<StringAttr>().end());
 
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
-  if (payloadOps.size() != 1)
+  if (payloadOps.size() != 1) {
+    results.set(getResult().cast<OpResult>(), {});
     return DiagnosedSilenceableFailure(
         this->emitOpError("requires exactly one target handle"));
+  }
 
   SmallVector<Operation *> res;
   auto matchFun = [&](Operation *op) {
@@ -877,8 +879,11 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
           }
           return OpFoldResult(op->getResult(0));
         }));
-    if (!diag.succeeded())
+    if (diag.isSilenceableFailure()) {
+      results.set(getFirst().cast<OpResult>(), {});
+      results.set(getSecond().cast<OpResult>(), {});
       return diag;
+    }
 
     if (splitPoints.size() != payload.size()) {
       emitError() << "expected the dynamic split point handle to point to as "
@@ -900,6 +905,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
     if (!linalgOp) {
       auto diag = emitSilenceableError() << "only applies to structured ops";
       diag.attachNote(target->getLoc()) << "target op";
+      results.set(getFirst().cast<OpResult>(), {});
+      results.set(getSecond().cast<OpResult>(), {});
       return diag;
     }
 
@@ -907,6 +914,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
       auto diag = emitSilenceableError() << "dimension " << getDimension()
                                          << " does not exist in target op";
       diag.attachNote(target->getLoc()) << "target op";
+      results.set(getFirst().cast<OpResult>(), {});
+      results.set(getSecond().cast<OpResult>(), {});
       return diag;
     }
 
index 5f32e78..fde8f20 100644 (file)
@@ -47,6 +47,7 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
                                            << scf::ForOp::getOperationName()
                                            << "' parent";
         diag.attachNote(target->getLoc()) << "target op";
+        results.set(getResult().cast<OpResult>(), {});
         return diag;
       }
       current = loop;
@@ -100,6 +101,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
       DiagnosedSilenceableFailure diag = emitSilenceableError()
                                          << "failed to outline";
       diag.attachNote(target->getLoc()) << "target op";
+      results.set(getTransformed().cast<OpResult>(), {});
       return diag;
     }
     func::CallOp call;
index 1d841f9..176e93a 100644 (file)
@@ -225,8 +225,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   }
 
   transform::TransformResults results(transform->getNumResults());
+  // Compute the result but do not short-circuit the silenceable failure case as
+  // we still want the handles to propagate properly so the "suppress" mode can
+  // proceed on a best effort basis.
   DiagnosedSilenceableFailure result(transform.apply(results, *this));
-  if (!result.succeeded())
+  if (result.isDefiniteFailure())
     return result;
 
   // Remove the mapping for the operand if it is consumed by the operation. This
@@ -258,7 +261,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     DBGS() << "Top-level payload:\n";
     getTopLevel()->print(llvm::dbgs());
   });
-  return DiagnosedSilenceableFailure::success();
+  return result;
 }
 
 //===----------------------------------------------------------------------===//
index 26a2baf..126500f 100644 (file)
@@ -23,6 +23,7 @@
 
 using namespace mlir;
 
+/// Custom parser for ReplicateOp.
 static ParseResult parsePDLOpTypedResults(
     OpAsmParser &parser, SmallVectorImpl<Type> &types,
     const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
@@ -30,9 +31,23 @@ static ParseResult parsePDLOpTypedResults(
   return success();
 }
 
+/// Custom printer for ReplicateOp.
 static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
                                    ValueRange) {}
 
+/// Custom parser for SplitHandlesOp.
+static ParseResult parseStaticNumPDLResults(OpAsmParser &parser,
+                                            SmallVectorImpl<Type> &types,
+                                            IntegerAttr numHandlesAttr) {
+  types.resize(numHandlesAttr.getInt(),
+               pdl::OperationType::get(parser.getContext()));
+  return success();
+}
+
+/// Custom printer for SplitHandlesOp.
+static void printStaticNumPDLResults(OpAsmPrinter &, Operation *, TypeRange,
+                                     IntegerAttr) {}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
 
@@ -453,6 +468,46 @@ OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// SplitHandlesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::SplitHandlesOp::apply(transform::TransformResults &results,
+                                 transform::TransformState &state) {
+  int64_t numResultHandles =
+      getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
+  int64_t expectedNumResultHandles = getNumResultHandles();
+  if (numResultHandles != expectedNumResultHandles) {
+    // Failing case needs to propagate gracefully for both suppress and
+    // propagate modes.
+    for (int64_t idx = 0; idx < expectedNumResultHandles; ++idx)
+      results.set(getResults()[idx].cast<OpResult>(), {});
+    // Empty input handle corner case: always propagates empty handles in both
+    // suppress and propagate modes.
+    if (numResultHandles == 0)
+      return DiagnosedSilenceableFailure::success();
+    // If the input handle was not empty and the number of result handles does
+    // not match, this is a legit silenceable error.
+    return emitSilenceableError()
+           << getHandle() << " expected to contain " << expectedNumResultHandles
+           << " operation handles but it only contains " << numResultHandles
+           << " handles";
+  }
+  // Normal successful case.
+  for (auto en : llvm::enumerate(state.getPayloadOps(getHandle())))
+    results.set(getResults()[en.index()].cast<OpResult>(), en.value());
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SplitHandlesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getHandle(), effects);
+  producesHandle(getResults(), effects);
+  // There are no effects on the Payload IR as this is only a handle
+  // manipulation.
+}
+
+//===----------------------------------------------------------------------===//
 // PDLMatchOp
 //===----------------------------------------------------------------------===//
 
index 7c0deb1..ddfd645 100644 (file)
@@ -761,3 +761,42 @@ transform.sequence failures(propagate) {
 
 }
 
+// -----
+
+func.func @split_handles(%a: index, %b: index, %c: index) {
+  %0 = arith.muli %a, %b : index  
+  %1 = arith.muli %a, %c : index  
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%fun: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %fun
+  %h:2 = split_handles %muli in [2]
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#0
+  %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun
+  // expected-error @below {{expected to contain 3 operation handles but it only contains 2 handles}}
+  %h_2:3 = split_handles %muli_2 in [3]
+}
+
+// -----
+
+func.func @split_handles(%a: index, %b: index, %c: index) {
+  %0 = arith.muli %a, %b : index  
+  %1 = arith.muli %a, %c : index  
+  return
+}
+
+transform.sequence failures(suppress) {
+^bb1(%fun: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %fun
+  %h:2 = split_handles %muli in [2]
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#0
+  %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun
+  // Silenceable failure and all handles are now empty.
+  %h_2:3 = split_handles %muli_2 in [3]
+  // expected-remark @below {{0}}
+ transform.test_print_number_of_associated_payload_ir_ops %h_2#0
+}
index 752aaff..3994f0e 100644 (file)
@@ -292,6 +292,8 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
 DiagnosedSilenceableFailure
 mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
     transform::TransformResults &results, transform::TransformState &state) {
+  if (!getHandle())
+    emitRemark() << 0;
   emitRemark() << state.getPayloadOps(getHandle()).size();
   return DiagnosedSilenceableFailure::success();
 }