[mlir] Transform dialect: introduce merge_handles op
authorAlex Zinenko <zinenko@google.com>
Thu, 7 Jul 2022 11:11:34 +0000 (13:11 +0200)
committerAlex Zinenko <zinenko@google.com>
Thu, 7 Jul 2022 11:19:46 +0000 (13:19 +0200)
This Transform dialect op allows one to merge the lists of Payload IR
operations pointed to by several handles into a single list associated with one
handle. This is an important Transform dialect usability improvement for cases
where transformations may temporarily diverge for different groups of Payload
IR ops before converging back to the same script. Without this op, several
copies of the trailing transformations would have to be present in the
transformation script.

Depends On D129090

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129110

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/python/mlir/dialects/_transform_ops_ext.py
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/python/dialects/transform.py

index dd0e82e4ada7923324c74894b3814604411b3ac1..ee86c888fb7bed32c4c8c61d6cf86e79c0097a1a 100644 (file)
@@ -121,6 +121,28 @@ def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent
   let assemblyFormat = "$target attr-dict";
 }
 
+def MergeHandlesOp : TransformDialectOp<"merge_handles",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Merges handles into one pointing to the union of payload ops";
+  let description = [{
+    Creates a new Transform IR handle value that points to the same Payload IR
+    operations as the operand handles. The Payload IR operations are listed
+    in the same order as they are in the operand handles, grouped by operand
+    handle, e.g., all Payload IR operations associated with the first handle
+    come first, then all Payload IR operations associated with the second handle
+    and so on. If `deduplicate` is set, do not add the given Payload IR
+    operation more than once to the final list regardless of it coming from the
+    same or different handles. Consumes the operands and produces a new handle.
+  }];
+
+  let arguments = (ins Variadic<PDL_Operation>:$handles,
+                       UnitAttr:$deduplicate);
+  let results = (outs PDL_Operation:$result);
+  let assemblyFormat = "($deduplicate^)? $handles attr-dict";
+  let hasFolder = 1;
+}
+
 def PDLMatchOp : TransformDialectOp<"pdl_match",
     [DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Finds ops that match the named PDL pattern";
index d83a9de96be9d42f5d67d265bf3d2ecf1677dd7c..1ff58853136d605a2562271bd408de3c5b9d9264 100644 (file)
@@ -286,6 +286,52 @@ DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// MergeHandlesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MergeHandlesOp::apply(transform::TransformResults &results,
+                                 transform::TransformState &state) {
+  SmallVector<Operation *> operations;
+  for (Value operand : getHandles())
+    llvm::append_range(operations, state.getPayloadOps(operand));
+  if (!getDeduplicate()) {
+    results.set(getResult().cast<OpResult>(), operations);
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  SetVector<Operation *> uniqued(operations.begin(), operations.end());
+  results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MergeHandlesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (Value operand : getHandles()) {
+    effects.emplace_back(MemoryEffects::Read::get(), operand,
+                         transform::TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Free::get(), operand,
+                         transform::TransformMappingResource::get());
+  }
+  effects.emplace_back(MemoryEffects::Allocate::get(), getResult(),
+                       transform::TransformMappingResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), getResult(),
+                       transform::TransformMappingResource::get());
+
+  // There are no effects on the Payload IR as this is only a handle
+  // manipulation.
+}
+
+OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
+  if (getDeduplicate() || getHandles().size() != 1)
+    return {};
+
+  // If deduplication is not required and there is only one operand, it can be
+  // used directly instead of merging.
+  return getHandles().front();
+}
+
 //===----------------------------------------------------------------------===//
 // PDLMatchOp
 //===----------------------------------------------------------------------===//
index 138195dca73e806fcffc19d0718a7dd751afebb8..ca45ab7e281762868fc9f856f08f12301a315e48 100644 (file)
@@ -28,6 +28,21 @@ class GetClosestIsolatedParentOp:
         ip=ip)
 
 
+class MergeHandlesOp:
+
+  def __init__(self,
+               handles: Sequence[Union[Operation, Value]],
+               *,
+               deduplicate: bool = False,
+               loc=None,
+               ip=None):
+    super().__init__(
+        pdl.OperationType.get(), [_get_op_result_or_value(h) for h in handles],
+        deduplicate=deduplicate,
+        loc=loc,
+        ip=ip)
+
+
 class PDLMatchOp:
 
   def __init__(self,
index 34d1fc8a2b174549dd48df59b3ffb32275be6a6c..e2fe60754ad81c1f0936d08ec6d9c0b8cd3ca0cb 100644 (file)
@@ -460,3 +460,42 @@ transform.with_pdl_patterns {
     %1:2 = transform.test_correct_number_of_multi_results %0
   }
 }
+
+// -----
+
+// Expecting to match all operations by merging the handles that matched addi
+// and subi separately.
+func.func @foo(%arg0: index) {
+  // expected-remark @below {{matched}}
+  %0 = arith.addi %arg0, %arg0 : index
+  // expected-remark @below {{matched}}
+  %1 = arith.subi %arg0, %arg0 : index
+  // expected-remark @below {{matched}}
+  %2 = arith.addi %0, %1 : index
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @addi : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "arith.addi"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @subi : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "arith.subi"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @addi in %arg1
+    %1 = pdl_match @subi in %arg1
+    %2 = merge_handles %0, %1
+    test_print_remark_at_operand %2, "matched"
+  }
+}
+
index 472201731ba51363d550e96cab3bc86afad30ee6..21392ca7d5e32761166927b5d86ecd3ceab264f8 100644 (file)
@@ -82,3 +82,15 @@ def testGetClosestIsolatedParentOp():
   # CHECK: transform.sequence
   # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
   # CHECK:   = get_closest_isolated_parent %[[ARG1]]
+
+
+@run
+def testMergeHandlesOp():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    transform.MergeHandlesOp([sequence.bodyTarget])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testMergeHandlesOp
+  # CHECK: transform.sequence
+  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
+  # CHECK:   = merge_handles %[[ARG1]]