let assemblyFormat = "$target attr-dict";
}
+def MergeHandlesOp : TransformDialectOp<"merge_handles",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let summary = "Merges handles into one pointing to the union of payload ops";
+ let description = [{
+ Creates a new Transform IR handle value that points to the same Payload IR
+ operations as the operand handles. The Payload IR operations are listed
+ in the same order as they are in the operand handles, grouped by operand
+ handle, e.g., all Payload IR operations associated with the first handle
+ come first, then all Payload IR operations associated with the second handle
+ and so on. If `deduplicate` is set, do not add the given Payload IR
+ operation more than once to the final list regardless of it coming from the
+ same or different handles. Consumes the operands and produces a new handle.
+ }];
+
+ let arguments = (ins Variadic<PDL_Operation>:$handles,
+ UnitAttr:$deduplicate);
+ let results = (outs PDL_Operation:$result);
+ let assemblyFormat = "($deduplicate^)? $handles attr-dict";
+ let hasFolder = 1;
+}
+
def PDLMatchOp : TransformDialectOp<"pdl_match",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
let summary = "Finds ops that match the named PDL pattern";
}
//===----------------------------------------------------------------------===//
+// MergeHandlesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MergeHandlesOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Operation *> operations;
+ for (Value operand : getHandles())
+ llvm::append_range(operations, state.getPayloadOps(operand));
+ if (!getDeduplicate()) {
+ results.set(getResult().cast<OpResult>(), operations);
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ SetVector<Operation *> uniqued(operations.begin(), operations.end());
+ results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MergeHandlesOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ for (Value operand : getHandles()) {
+ effects.emplace_back(MemoryEffects::Read::get(), operand,
+ transform::TransformMappingResource::get());
+ effects.emplace_back(MemoryEffects::Free::get(), operand,
+ transform::TransformMappingResource::get());
+ }
+ effects.emplace_back(MemoryEffects::Allocate::get(), getResult(),
+ transform::TransformMappingResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(), getResult(),
+ transform::TransformMappingResource::get());
+
+ // There are no effects on the Payload IR as this is only a handle
+ // manipulation.
+}
+
+OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
+ if (getDeduplicate() || getHandles().size() != 1)
+ return {};
+
+ // If deduplication is not required and there is only one operand, it can be
+ // used directly instead of merging.
+ return getHandles().front();
+}
+
+//===----------------------------------------------------------------------===//
// PDLMatchOp
//===----------------------------------------------------------------------===//
ip=ip)
+class MergeHandlesOp:
+
+ def __init__(self,
+ handles: Sequence[Union[Operation, Value]],
+ *,
+ deduplicate: bool = False,
+ loc=None,
+ ip=None):
+ super().__init__(
+ pdl.OperationType.get(), [_get_op_result_or_value(h) for h in handles],
+ deduplicate=deduplicate,
+ loc=loc,
+ ip=ip)
+
+
class PDLMatchOp:
def __init__(self,
%1:2 = transform.test_correct_number_of_multi_results %0
}
}
+
+// -----
+
+// Expecting to match all operations by merging the handles that matched addi
+// and subi separately.
+func.func @foo(%arg0: index) {
+ // expected-remark @below {{matched}}
+ %0 = arith.addi %arg0, %arg0 : index
+ // expected-remark @below {{matched}}
+ %1 = arith.subi %arg0, %arg0 : index
+ // expected-remark @below {{matched}}
+ %2 = arith.addi %0, %1 : index
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @addi : benefit(1) {
+ %0 = pdl.operands
+ %1 = pdl.types
+ %2 = pdl.operation "arith.addi"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ pdl.rewrite %2 with "transform.dialect"
+ }
+ pdl.pattern @subi : benefit(1) {
+ %0 = pdl.operands
+ %1 = pdl.types
+ %2 = pdl.operation "arith.subi"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ pdl.rewrite %2 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = pdl_match @addi in %arg1
+ %1 = pdl_match @subi in %arg1
+ %2 = merge_handles %0, %1
+ test_print_remark_at_operand %2, "matched"
+ }
+}
+
# CHECK: transform.sequence
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
# CHECK: = get_closest_isolated_parent %[[ARG1]]
+
+
+@run
+def testMergeHandlesOp():
+ sequence = transform.SequenceOp()
+ with InsertionPoint(sequence.body):
+ transform.MergeHandlesOp([sequence.bodyTarget])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMergeHandlesOp
+ # CHECK: transform.sequence
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
+ # CHECK: = merge_handles %[[ARG1]]