The transform.split_handles op is useful for ensuring a statically known number of operations are
tracked by the source `handle` and to extract them into individual handles
that can be further manipulated in isolation.
In the process of making the op robust wrt to silenceable errors and the suppress mode, issues were
uncovered and fixed.
The main issue was that silenceable errors were short-circuited too early and the payloads were not
set. This resulted in suppressed silenceable errors not propagating correctly.
Fixing the issue triggered a few test failures: silenceable error returns now must properly set the results state.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D135426
let hasFolder = 1;
}
+def SplitHandlesOp : TransformDialectOp<"split_handles",
+ [FunctionalStyleTransformOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let summary = "Splits handles from a union of payload ops to a list";
+ let description = [{
+ Creates `num_result_handles` transform IR handles extracted from the
+ `handle` operand. The resulting Payload IR operation handles are listed
+ in the same order as the operations appear in the source `handle`.
+ This is useful for ensuring a statically known number of operations are
+ tracked by the source `handle` and to extract them into individual handles
+ that can be further manipulated in isolation.
+
+ This operation succeeds and returns `num_result_handles` if the statically
+ specified `num_result_handles` corresponds to the dynamic number of
+ operations contained in the source `handle`. Otherwise it silently fails.
+ }];
+
+ let arguments = (ins PDL_Operation:$handle,
+ I64Attr:$num_result_handles);
+ let results = (outs Variadic<PDL_Operation>:$results);
+ let assemblyFormat = [{
+ $handle `in` `[` $num_result_handles `]`
+ custom<StaticNumPDLResults>(type($results), ref($num_result_handles))
+ attr-dict
+ }];
+}
+
def PDLMatchOp : TransformDialectOp<"pdl_match",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
getOps()->getAsValueRange<StringAttr>().end());
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
- if (payloadOps.size() != 1)
+ if (payloadOps.size() != 1) {
+ results.set(getResult().cast<OpResult>(), {});
return DiagnosedSilenceableFailure(
this->emitOpError("requires exactly one target handle"));
+ }
SmallVector<Operation *> res;
auto matchFun = [&](Operation *op) {
}
return OpFoldResult(op->getResult(0));
}));
- if (!diag.succeeded())
+ if (diag.isSilenceableFailure()) {
+ results.set(getFirst().cast<OpResult>(), {});
+ results.set(getSecond().cast<OpResult>(), {});
return diag;
+ }
if (splitPoints.size() != payload.size()) {
emitError() << "expected the dynamic split point handle to point to as "
if (!linalgOp) {
auto diag = emitSilenceableError() << "only applies to structured ops";
diag.attachNote(target->getLoc()) << "target op";
+ results.set(getFirst().cast<OpResult>(), {});
+ results.set(getSecond().cast<OpResult>(), {});
return diag;
}
auto diag = emitSilenceableError() << "dimension " << getDimension()
<< " does not exist in target op";
diag.attachNote(target->getLoc()) << "target op";
+ results.set(getFirst().cast<OpResult>(), {});
+ results.set(getSecond().cast<OpResult>(), {});
return diag;
}
<< scf::ForOp::getOperationName()
<< "' parent";
diag.attachNote(target->getLoc()) << "target op";
+ results.set(getResult().cast<OpResult>(), {});
return diag;
}
current = loop;
DiagnosedSilenceableFailure diag = emitSilenceableError()
<< "failed to outline";
diag.attachNote(target->getLoc()) << "target op";
+ results.set(getTransformed().cast<OpResult>(), {});
return diag;
}
func::CallOp call;
}
transform::TransformResults results(transform->getNumResults());
+ // Compute the result but do not short-circuit the silenceable failure case as
+ // we still want the handles to propagate properly so the "suppress" mode can
+ // proceed on a best effort basis.
DiagnosedSilenceableFailure result(transform.apply(results, *this));
- if (!result.succeeded())
+ if (result.isDefiniteFailure())
return result;
// Remove the mapping for the operand if it is consumed by the operation. This
DBGS() << "Top-level payload:\n";
getTopLevel()->print(llvm::dbgs());
});
- return DiagnosedSilenceableFailure::success();
+ return result;
}
//===----------------------------------------------------------------------===//
using namespace mlir;
+/// Custom parser for ReplicateOp.
static ParseResult parsePDLOpTypedResults(
OpAsmParser &parser, SmallVectorImpl<Type> &types,
const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
return success();
}
+/// Custom printer for ReplicateOp.
static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
ValueRange) {}
+/// Custom parser for SplitHandlesOp.
+static ParseResult parseStaticNumPDLResults(OpAsmParser &parser,
+ SmallVectorImpl<Type> &types,
+ IntegerAttr numHandlesAttr) {
+ types.resize(numHandlesAttr.getInt(),
+ pdl::OperationType::get(parser.getContext()));
+ return success();
+}
+
+/// Custom printer for SplitHandlesOp.
+static void printStaticNumPDLResults(OpAsmPrinter &, Operation *, TypeRange,
+ IntegerAttr) {}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
}
//===----------------------------------------------------------------------===//
+// SplitHandlesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::SplitHandlesOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ int64_t numResultHandles =
+ getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
+ int64_t expectedNumResultHandles = getNumResultHandles();
+ if (numResultHandles != expectedNumResultHandles) {
+ // Failing case needs to propagate gracefully for both suppress and
+ // propagate modes.
+ for (int64_t idx = 0; idx < expectedNumResultHandles; ++idx)
+ results.set(getResults()[idx].cast<OpResult>(), {});
+ // Empty input handle corner case: always propagates empty handles in both
+ // suppress and propagate modes.
+ if (numResultHandles == 0)
+ return DiagnosedSilenceableFailure::success();
+ // If the input handle was not empty and the number of result handles does
+ // not match, this is a legit silenceable error.
+ return emitSilenceableError()
+ << getHandle() << " expected to contain " << expectedNumResultHandles
+ << " operation handles but it only contains " << numResultHandles
+ << " handles";
+ }
+ // Normal successful case.
+ for (auto en : llvm::enumerate(state.getPayloadOps(getHandle())))
+ results.set(getResults()[en.index()].cast<OpResult>(), en.value());
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SplitHandlesOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getHandle(), effects);
+ producesHandle(getResults(), effects);
+ // There are no effects on the Payload IR as this is only a handle
+ // manipulation.
+}
+
+//===----------------------------------------------------------------------===//
// PDLMatchOp
//===----------------------------------------------------------------------===//
}
+// -----
+
+func.func @split_handles(%a: index, %b: index, %c: index) {
+ %0 = arith.muli %a, %b : index
+ %1 = arith.muli %a, %c : index
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%fun: !pdl.operation):
+ %muli = transform.structured.match ops{["arith.muli"]} in %fun
+ %h:2 = split_handles %muli in [2]
+ // expected-remark @below {{1}}
+ transform.test_print_number_of_associated_payload_ir_ops %h#0
+ %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun
+ // expected-error @below {{expected to contain 3 operation handles but it only contains 2 handles}}
+ %h_2:3 = split_handles %muli_2 in [3]
+}
+
+// -----
+
+func.func @split_handles(%a: index, %b: index, %c: index) {
+ %0 = arith.muli %a, %b : index
+ %1 = arith.muli %a, %c : index
+ return
+}
+
+transform.sequence failures(suppress) {
+^bb1(%fun: !pdl.operation):
+ %muli = transform.structured.match ops{["arith.muli"]} in %fun
+ %h:2 = split_handles %muli in [2]
+ // expected-remark @below {{1}}
+ transform.test_print_number_of_associated_payload_ir_ops %h#0
+ %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun
+ // Silenceable failure and all handles are now empty.
+ %h_2:3 = split_handles %muli_2 in [3]
+ // expected-remark @below {{0}}
+ transform.test_print_number_of_associated_payload_ir_ops %h_2#0
+}
DiagnosedSilenceableFailure
mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
transform::TransformResults &results, transform::TransformState &state) {
+ if (!getHandle())
+ emitRemark() << 0;
emitRemark() << state.getPayloadOps(getHandle()).size();
return DiagnosedSilenceableFailure::success();
}