[mlir][transform] Add ForeachOp to transform dialect
authorMatthias Springer <springerm@google.com>
Tue, 26 Jul 2022 16:06:57 +0000 (18:06 +0200)
committerMatthias Springer <springerm@google.com>
Tue, 26 Jul 2022 16:07:44 +0000 (18:07 +0200)
This op "unbatches" an op handle and executes the loop body for each payload op.

Differential Revision: https://reviews.llvm.org/D130257

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/Dialect/Transform/ops.mlir
mlir/test/Dialect/Transform/test-interpreter.mlir

index 27a68b0..f11e3da 100644 (file)
@@ -64,6 +64,14 @@ def Transform_Dialect : Dialect {
     correspond to groups of outer and inner loops, respectively, produced by
     the tiling transformation.
 
+    A Transform IR value such as `%0` may be associated with multiple payload
+    operations. This is conceptually a set of operations and no assumptions
+    should be made about the order of ops. Most Transform IR ops support
+    operand values that are mapped to multiple operations. They usually apply
+    the respective transformation for every mapped op ("batched execution").
+    Deviations from this convention are described in the documentation of
+    Transform IR ops.
+
     Overall, Transform IR ops are expected to be contained in a single top-level
     op. Such top-level ops specify how to apply the transformations described
     by the operations they contain, e.g., `transform.sequence` executes
index d578c15..bc5fd01 100644 (file)
@@ -95,6 +95,46 @@ def AlternativesOp : TransformDialectOp<"alternatives",
   let hasVerifier = 1;
 }
 
+def ForeachOp : TransformDialectOp<"foreach",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+         "getSuccessorRegions", "getSuccessorEntryOperands"]>,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
+    ]> {
+  let summary = "Executes the body for each payload op";
+  let description = [{
+    This op has exactly one region with exactly one block ("body"). The body is
+    executed for each payload op that is associated to the target operand in an
+    unbatched fashion. I.e., the block argument ("iteration variable") is always
+    mapped to exactly one payload op.
+
+    This op always reads the target handle. Furthermore, it consumes the handle
+    if there is a transform op in the body that consumes the iteration variable.
+    This op does not return anything.
+
+    The transformations inside the body are applied in order of their
+    appearance. During application, if any transformation in the sequence fails,
+    the entire sequence fails immediately leaving the payload IR in potentially
+    invalid state, i.e., this operation offers no transformation rollback
+    capabilities.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = "$target $body attr-dict";
+
+  let extraClassDeclaration = [{
+    /// Allow the dialect prefix to be omitted.
+    static StringRef getDefaultDialect() { return "transform"; }
+
+    BlockArgument getIterationVariable() {
+      return getBody().front().getArgument(0);
+    }
+  }];
+}
+
 def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
index bc2d971..5d702d9 100644 (file)
@@ -274,6 +274,64 @@ LogicalResult transform::AlternativesOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
+// ForeachOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ForeachOp::apply(transform::TransformResults &results,
+                            transform::TransformState &state) {
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  for (Operation *op : payloadOps) {
+    auto scope = state.make_region_scope(getBody());
+    if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    for (Operation &transform : getBody().front().without_terminator()) {
+      DiagnosedSilenceableFailure result = state.applyTransform(
+          cast<transform::TransformOpInterface>(transform));
+      if (!result.succeeded())
+        return result;
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ForeachOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  BlockArgument iterVar = getIterationVariable();
+  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+        return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
+      })) {
+    consumesHandle(getTarget(), effects);
+  } else {
+    onlyReadsHandle(getTarget(), effects);
+  }
+}
+
+void transform::ForeachOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  Region *bodyRegion = &getBody();
+  if (!index) {
+    regions.emplace_back(bodyRegion, bodyRegion->getArguments());
+    return;
+  }
+
+  // Branch back to the region or the parent.
+  assert(*index == 0 && "unexpected region index");
+  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
+  regions.emplace_back();
+}
+
+OperandRange
+transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
+  // The iteration variable op handle is mapped to a subset (one op to be
+  // precise) of the payload ops of the ForeachOp operand.
+  assert(index && *index == 0 && "unexpected region index");
+  return getOperation()->getOperands();
+}
+
+//===----------------------------------------------------------------------===//
 // GetClosestIsolatedParentOp
 //===----------------------------------------------------------------------===//
 
index b76bd07..2650651 100644 (file)
@@ -184,3 +184,18 @@ transform.alternatives {
 ^bb0:
   transform.yield
 }
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{result #0 has more than one potential consumer}}
+  %0 = test_produce_param_or_forward_operand 42
+  // expected-note @below {{used here as operand #0}}
+  transform.foreach %0 {
+  ^bb1(%arg1: !pdl.operation):
+    transform.test_consume_operand %arg1
+  }
+  // expected-note @below {{used here as operand #0}}
+  transform.test_consume_operand %0
+}
index e9e99de..23dd6b8 100644 (file)
@@ -49,3 +49,12 @@ transform.sequence {
   ^bb3(%arg3: !pdl.operation):
   }
 }
+
+// CHECK: transform.sequence
+// CHECK: foreach
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  transform.foreach %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+  }
+}
index 1f0dd40..ed3a250 100644 (file)
@@ -597,3 +597,33 @@ module {
     }
   }
 }
+
+// -----
+
+func.func @bar() {
+  // expected-remark @below {{transform applied}}
+  %0 = arith.constant 0 : i32
+  // expected-remark @below {{transform applied}}
+  %1 = arith.constant 1 : i32
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @const : benefit(1) {
+    %r = pdl.types
+    %0 = pdl.operation "arith.constant" -> (%r : !pdl.range<type>)
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %f = pdl_match @const in %arg1
+    transform.foreach %f {
+    ^bb2(%arg2: !pdl.operation):
+      // expected-remark @below {{1}}
+      transform.test_print_number_of_associated_payload_ir_ops %arg2
+      transform.test_print_remark_at_operand %arg2, "transform applied"
+    }
+  }
+}