[mlir] multi-argument binding for top-level transform ops
authorAlex Zinenko <zinenko@google.com>
Wed, 25 Jan 2023 16:53:25 +0000 (16:53 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 31 Jan 2023 14:21:28 +0000 (14:21 +0000)
`applyTransforms` now takes an optional mapping to be associated with
trailing block arguments of the top-level transform op, in addition to
the payload root. This allows for more advanced forms of communication
between C++ code and the transform dialect interpreter, in particular
supplying operations without having to re-match them during
interpretation.

Reviewed By: shabalin

Differential Revision: https://reviews.llvm.org/D142559

13 files changed:
mlir/docs/Dialects/Transform.md
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/python/mlir/dialects/_transform_ops_ext.py
mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir [new file with mode: 0644]
mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir [new file with mode: 0644]
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/Dialect/Transform/ops.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
mlir/test/python/dialects/transform.py

index eb86bdc..e73bb79 100644 (file)
@@ -109,13 +109,19 @@ A program transformation expressed using the Transform dialect can be
 programmatically triggered by calling:
 
 ```c++
-LogicalResult transform::applyTransforms(Operation *payloadRoot,
-                                         TransformOpInterface transform,
-                                         const TransformOptions &options);
+LogicalResult transform::applyTransforms(
+    Operation *payloadRoot,
+    ArrayRef<ArrayRef<PointerUnion<Operation *, Attribute>> extraMappings,
+    TransformOpInterface transform,
+    const TransformOptions &options);
 ```
 
 that applies the transformations specified by the top-level `transform` to
-payload IR contained in `payloadRoot`.
+payload IR contained in `payloadRoot`. The payload root operation will be
+associated with the first argument of the entry block of the top-level transform
+op. This block may have additional arguments, handles or parameters. They will
+be associated with values provided as `extraMappings`. The call will report an
+error and return if the wrong number of mappings is provided.
 
 ## Dialect Extension Mechanism
 
index e523c08..063b6de 100644 (file)
@@ -42,6 +42,9 @@ private:
   bool expensiveChecksEnabled = true;
 };
 
+using Param = Attribute;
+using MappedValue = llvm::PointerUnion<Operation *, Param>;
+
 /// Entry point to the Transform dialect infrastructure. Applies the
 /// transformation specified by `transform` to payload IR contained in
 /// `payloadRoot`. The `transform` operation may contain other operations that
@@ -50,6 +53,7 @@ private:
 /// This function internally keeps track of the transformation state.
 LogicalResult
 applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
+                ArrayRef<ArrayRef<MappedValue>> extraMapping = {},
                 const TransformOptions &options = TransformOptions());
 
 /// The state maintained across applications of various ops implementing the
@@ -85,7 +89,7 @@ applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
 /// using `mapBlockArguments`.
 class TransformState {
 public:
-  using Param = Attribute;
+  using Param = transform::Param;
 
 private:
   /// Mapping between a Value in the transform IR and the corresponding set of
@@ -109,15 +113,23 @@ private:
     ParamMapping params;
   };
 
-  friend LogicalResult applyTransforms(Operation *payloadRoot,
-                                       TransformOpInterface transform,
-                                       const TransformOptions &options);
+  friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
+                                       ArrayRef<ArrayRef<MappedValue>>,
+                                       const TransformOptions &);
 
 public:
   /// Returns the op at which the transformation state is rooted. This is
   /// typically helpful for transformations that apply globally.
   Operation *getTopLevel() const;
 
+  /// Returns the number of extra mappings for the top-level operation.
+  size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); }
+
+  /// Returns the position-th extra mapping for the top-level operation.
+  ArrayRef<MappedValue> getTopLevelMapping(size_t position) const {
+    return topLevelMappedValues[position];
+  }
+
   /// Returns the list of ops that the given transform IR value corresponds to.
   /// This is helpful for transformations that apply to a particular handle.
   ArrayRef<Operation *> getPayloadOps(Value value) const;
@@ -150,6 +162,8 @@ public:
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
     return setPayloadOps(argument, operations);
   }
+  LogicalResult mapBlockArgument(BlockArgument argument,
+                                 ArrayRef<MappedValue> values);
 
   // Forward declarations to support limited visibility.
   class RegionScope;
@@ -302,6 +316,7 @@ private:
   /// which may or may not contain the region with transform ops. Additional
   /// options can be provided through the trailing configuration object.
   TransformState(Region *region, Operation *payloadRoot,
+                 ArrayRef<ArrayRef<MappedValue>> extraMappings = {},
                  const TransformOptions &options = TransformOptions());
 
   /// Returns the mappings frame for the reigon in which the value is defined.
@@ -403,6 +418,15 @@ private:
   /// The top-level operation that contains all payload IR, typically a module.
   Operation *topLevel;
 
+  /// Storage for extra mapped values (payload operations or parameters) to be
+  /// associated with additional entry block arguments of the top-level
+  /// transform operation. Each entry in `topLevelMappedValues` is a reference
+  /// to a contiguous block in `topLevelMappedValueStorage`.
+  // TODO: turn this into a proper named data structure, there are several more
+  // below.
+  SmallVector<ArrayRef<MappedValue>> topLevelMappedValues;
+  SmallVector<MappedValue> topLevelMappedValueStorage;
+
   /// Additional options controlling the transformation state behavior.
   TransformOptions options;
 
index 0a737d5..7eb9a01 100644 (file)
@@ -26,6 +26,9 @@ class FailurePropagationModeAttr;
 /// A builder function that populates the body of a SequenceOp.
 using SequenceBodyBuilderFn = ::llvm::function_ref<void(
     ::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)>;
+using SequenceBodyBuilderArgsFn =
+    ::llvm::function_ref<void(::mlir::OpBuilder &, ::mlir::Location,
+                              ::mlir::BlockArgument, ::mlir::ValueRange)>;
 } // namespace transform
 } // namespace mlir
 
index 4bb6700..6f3b4cf 100644 (file)
@@ -384,7 +384,8 @@ def SequenceOp : TransformDialectOp<"sequence",
      DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
-     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
+     AttrSizedOperandSegments]> {
   let summary = "Contains a sequence of other transform ops to apply";
   let description = [{
     The transformations indicated by the sequence are applied in order of their
@@ -417,12 +418,14 @@ def SequenceOp : TransformDialectOp<"sequence",
   }];
 
   let arguments = (ins FailurePropagationMode:$failure_propagation_mode,
-                       Optional<TransformHandleTypeInterface>:$root);
+                       Optional<TransformHandleTypeInterface>:$root,
+                       Variadic<Transform_AnyHandleOrParamType>:$extra_bindings);
   let results = (outs Variadic<TransformHandleTypeInterface>:$results);
   let regions = (region SizedRegion<1>:$body);
 
   let assemblyFormat =
-    "($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` "
+    "custom<SequenceOpOperands>($root, type($root), $extra_bindings, type($extra_bindings))"
+    " (`->` type($results)^)? `failures` `(` "
     "$failure_propagation_mode `)` attr-dict-with-keyword regions";
 
   let builders = [
@@ -432,11 +435,25 @@ def SequenceOp : TransformDialectOp<"sequence",
         "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
         "::mlir::Value":$root, "SequenceBodyBuilderFn":$bodyBuilder)>,
 
-    // Build a sequence without a root but a certain bbArg type.
+    // Build a sequence with a root and additional arguments.
+    OpBuilder<(ins
+        "::mlir::TypeRange":$resultTypes,
+        "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
+        "::mlir::Value":$root, "::mlir::ValueRange":$extraBindings,
+        "SequenceBodyBuilderArgsFn":$bodyBuilder)>,
+
+    // Build a top-level sequence (no root).
+    OpBuilder<(ins
+        "::mlir::TypeRange":$resultTypes,
+        "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
+        "::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>,
+
+    // Build a top-level sequence (no root) with extra arguments.
     OpBuilder<(ins
         "::mlir::TypeRange":$resultTypes,
         "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
-        "::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>
+        "::mlir::Type":$bbArgType, "::mlir::TypeRange":$extraBindingTypes,
+        "SequenceBodyBuilderArgsFn":$bodyBuilder)>
   ];
 
   let extraClassDeclaration = [{
index e2ab48e..5ecc1f4 100644 (file)
@@ -27,10 +27,20 @@ using namespace mlir;
 
 constexpr const Value transform::TransformState::kTopLevelValue;
 
-transform::TransformState::TransformState(Region *region,
-                                          Operation *payloadRoot,
-                                          const TransformOptions &options)
+transform::TransformState::TransformState(
+    Region *region, Operation *payloadRoot,
+    ArrayRef<ArrayRef<MappedValue>> extraMappings,
+    const TransformOptions &options)
     : topLevel(payloadRoot), options(options) {
+  topLevelMappedValues.reserve(extraMappings.size());
+  for (ArrayRef<MappedValue> mapping : extraMappings) {
+    size_t start = topLevelMappedValueStorage.size();
+    llvm::append_range(topLevelMappedValueStorage, mapping);
+    topLevelMappedValues.push_back(
+        ArrayRef<MappedValue>(topLevelMappedValueStorage)
+            .slice(start, mapping.size()));
+  }
+
   auto result = mappings.try_emplace(region);
   assert(result.second && "the region scope is already present");
   (void)result;
@@ -73,6 +83,38 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
 }
 
 LogicalResult
+transform::TransformState::mapBlockArgument(BlockArgument argument,
+                                            ArrayRef<MappedValue> values) {
+  if (argument.getType().isa<TransformHandleTypeInterface>()) {
+    SmallVector<Operation *> operations;
+    operations.reserve(values.size());
+    for (MappedValue value : values) {
+      if (auto *op = value.dyn_cast<Operation *>()) {
+        operations.push_back(op);
+        continue;
+      }
+      return emitError(argument.getLoc())
+             << "wrong kind of value provided for top-level operation handle";
+    }
+    return setPayloadOps(argument, operations);
+  }
+
+  assert(argument.getType().isa<TransformParamTypeInterface>() &&
+         "unsupported kind of block argument");
+  SmallVector<Param> parameters;
+  parameters.reserve(values.size());
+  for (MappedValue value : values) {
+    if (auto attr = value.dyn_cast<Attribute>()) {
+      parameters.push_back(attr);
+      continue;
+    }
+    return emitError(argument.getLoc())
+           << "wrong kind of value provided for top-level parameter";
+  }
+  return setParams(argument, parameters);
+}
+
+LogicalResult
 transform::TransformState::setPayloadOps(Value value,
                                          ArrayRef<Operation *> targets) {
   assert(value != kTopLevelValue &&
@@ -522,12 +564,43 @@ void transform::detail::setApplyToOneResults(
 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
     TransformState &state, Operation *op, Region &region) {
   SmallVector<Operation *> targets;
-  if (op->getNumOperands() != 0)
+  SmallVector<SmallVector<MappedValue>> extraMappings;
+  if (op->getNumOperands() != 0) {
     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
-  else
+    for (Value operand : op->getOperands().drop_front()) {
+      SmallVector<MappedValue> &mapped = extraMappings.emplace_back();
+      if (operand.getType().isa<TransformHandleTypeInterface>()) {
+        llvm::append_range(mapped, state.getPayloadOps(operand));
+      } else {
+        assert(operand.getType().isa<TransformParamTypeInterface>() &&
+               "unsupported kind of transform dialect value");
+        llvm::append_range(mapped, state.getParams(operand));
+      }
+    }
+  } else {
+    if (state.getNumTopLevelMappings() !=
+        region.front().getNumArguments() - 1) {
+      return emitError(op->getLoc())
+             << "operation expects " << region.front().getNumArguments() - 1
+             << " extra value bindings, but " << state.getNumTopLevelMappings()
+             << " were provided to the interpreter";
+    }
+
     targets.push_back(state.getTopLevel());
+    for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
+      extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
+  }
+
+  if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
+    return failure();
+
+  for (BlockArgument argument : region.front().getArguments().drop_front()) {
+    if (failed(state.mapBlockArgument(
+            argument, extraMappings[argument.getArgNumber() - 1])))
+      return failure();
+  }
 
-  return state.mapBlockArguments(region.front().getArgument(0), targets);
+  return success();
 }
 
 LogicalResult
@@ -547,19 +620,42 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
     return op->emitOpError() << "expects a single-block region";
 
   Block *body = &bodyRegion->front();
-  if (body->getNumArguments() != 1 ||
-      !body->getArgumentTypes()[0].isa<TransformHandleTypeInterface>()) {
+  if (body->getNumArguments() == 0) {
+    return op->emitOpError()
+           << "expects the entry block to have at least one argument";
+  }
+  if (!body->getArgument(0).getType().isa<TransformHandleTypeInterface>()) {
     return op->emitOpError()
-           << "expects the entry block to have one argument "
-              "of type implementing TransformHandleTypeInterface";
+           << "expects the first entry block argument to be of type "
+              "implementing TransformHandleTypeInterface";
+  }
+  BlockArgument arg = body->getArgument(0);
+  if (op->getNumOperands() != 0) {
+    if (arg.getType() != op->getOperand(0).getType()) {
+      return op->emitOpError()
+             << "expects the type of the block argument to match "
+                "the type of the operand";
+    }
+  }
+  for (BlockArgument arg : body->getArguments().drop_front()) {
+    if (arg.getType()
+            .isa<TransformHandleTypeInterface, TransformParamTypeInterface>())
+      continue;
+
+    InFlightDiagnostic diag =
+        op->emitOpError()
+        << "expects trailing entry block arguments to be of type implementing "
+           "TransformHandleTypeInterface or TransformParamTypeInterface";
+    diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
+    return diag;
   }
 
   if (auto *parent =
           op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
-    if (op->getNumOperands() == 0) {
+    if (op->getNumOperands() != body->getNumArguments()) {
       InFlightDiagnostic diag =
           op->emitOpError()
-          << "expects the root operation to be provided for a nested op";
+          << "expects operands to be provided for a nested op";
       diag.attachNote(parent->getLoc())
           << "nested in another possible top-level op";
       return diag;
@@ -717,9 +813,11 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
 // Entry point.
 //===----------------------------------------------------------------------===//
 
-LogicalResult transform::applyTransforms(Operation *payloadRoot,
-                                         TransformOpInterface transform,
-                                         const TransformOptions &options) {
+LogicalResult
+transform::applyTransforms(Operation *payloadRoot,
+                           TransformOpInterface transform,
+                           ArrayRef<ArrayRef<MappedValue>> extraMapping,
+                           const TransformOptions &options) {
 #ifndef NDEBUG
   if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
       transform->getNumOperands() != 0) {
@@ -730,7 +828,8 @@ LogicalResult transform::applyTransforms(Operation *payloadRoot,
   }
 #endif // NDEBUG
 
-  TransformState state(transform->getParentRegion(), payloadRoot, options);
+  TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
+                       options);
   return state.applyTransform(transform).checkAndReport();
 }
 
index 8bd5cab..0314932 100644 (file)
 
 using namespace mlir;
 
+static ParseResult parseSequenceOpOperands(
+    OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &root,
+    Type &rootType,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
+    SmallVectorImpl<Type> &extraBindingTypes);
+static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
+                                    Value root, Type rootType,
+                                    ValueRange extraBindings,
+                                    TypeRange extraBindingTypes);
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
 
@@ -654,6 +664,76 @@ transform::SequenceOp::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+static ParseResult parseSequenceOpOperands(
+    OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &root,
+    Type &rootType,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
+    SmallVectorImpl<Type> &extraBindingTypes) {
+  OpAsmParser::UnresolvedOperand rootOperand;
+  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
+  if (!hasRoot.has_value()) {
+    root = std::nullopt;
+    return success();
+  }
+  if (failed(hasRoot.value()))
+    return failure();
+  root = rootOperand;
+
+  if (succeeded(parser.parseOptionalComma())) {
+    if (failed(parser.parseOperandList(extraBindings)))
+      return failure();
+  }
+  if (failed(parser.parseColon()))
+    return failure();
+
+  // The paren is truly optional.
+  (void)parser.parseOptionalLParen();
+
+  if (failed(parser.parseType(rootType))) {
+    return failure();
+  }
+
+  if (!extraBindings.empty()) {
+    if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
+      return failure();
+  }
+
+  if (extraBindingTypes.size() != extraBindings.size()) {
+    return parser.emitError(parser.getNameLoc(),
+                            "expected types to be provided for all operands");
+  }
+
+  // The paren is truly optional.
+  (void)parser.parseOptionalRParen();
+  return success();
+}
+
+static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
+                                    Value root, Type rootType,
+                                    ValueRange extraBindings,
+                                    TypeRange extraBindingTypes) {
+  if (!root)
+    return;
+
+  printer << root;
+  bool hasExtras = !extraBindings.empty();
+  if (hasExtras) {
+    printer << ", ";
+    printer.printOperands(extraBindings);
+  }
+
+  printer << " : ";
+  if (hasExtras)
+    printer << "(";
+
+  printer << rootType;
+  if (hasExtras) {
+    printer << ", ";
+    llvm::interleaveComma(extraBindingTypes, printer.getStream());
+    printer << ")";
+  }
+}
+
 /// Returns `true` if the given op operand may be consuming the handle value in
 /// the Transform IR. That is, if it may have a Free effect on it.
 static bool isValueUsePotentialConsumer(OpOperand &use) {
@@ -691,22 +771,22 @@ checkDoubleConsume(Value value,
 }
 
 LogicalResult transform::SequenceOp::verify() {
-  assert(getBodyBlock()->getNumArguments() == 1 &&
-         "the number of arguments must have been verified to be 1 by "
+  assert(getBodyBlock()->getNumArguments() >= 1 &&
+         "the number of arguments must have been verified to be more than 1 by "
          "PossibleTopLevelTransformOpTrait");
 
-  BlockArgument arg = getBodyBlock()->getArgument(0);
-  if (getRoot()) {
-    if (arg.getType() != getRoot().getType()) {
-      return emitOpError() << "expects the type of the block argument to match "
-                              "the type of the operand";
-    }
+  if (!getRoot() && !getExtraBindings().empty()) {
+    return emitOpError()
+           << "does not expect extra operands when used as top-level";
   }
 
-  // Check if the block argument has more than one consuming use.
-  if (failed(checkDoubleConsume(
-          arg, [this]() { return (emitOpError() << "block argument #0"); }))) {
-    return failure();
+  // Check if a block argument has more than one consuming use.
+  for (BlockArgument arg : getBodyBlock()->getArguments()) {
+    if (failed(checkDoubleConsume(arg, [this, arg]() {
+          return (emitOpError() << "block argument #" << arg.getArgNumber());
+        }))) {
+      return failure();
+    }
   }
 
   // Check properties of the nested operations they cannot check themselves.
@@ -740,26 +820,26 @@ LogicalResult transform::SequenceOp::verify() {
   return success();
 }
 
+/// Appends to `effects` the memory effect instances on `target` with the same
+/// resource and effect as the ones the operation `iface` having on `source`.
+static void
+remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
+             SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  SmallVector<MemoryEffects::EffectInstance> nestedEffects;
+  iface.getEffectsOnValue(source, nestedEffects);
+  for (const auto &effect : nestedEffects)
+    effects.emplace_back(effect.getEffect(), target, effect.getResource());
+}
+
 void transform::SequenceOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  auto *mappingResource = TransformMappingResource::get();
-  effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
-
-  for (Value result : getResults()) {
-    effects.emplace_back(MemoryEffects::Allocate::get(), result,
-                         mappingResource);
-    effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
-  }
+  onlyReadsHandle(getRoot(), effects);
+  onlyReadsHandle(getExtraBindings(), effects);
+  producesHandle(getResults(), effects);
 
   if (!getRoot()) {
     for (Operation &op : *getBodyBlock()) {
-      auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
-      if (!iface) {
-        // TODO: fill all possible effects; or require ops to actually implement
-        // the memory effect interface always
-        assert(false);
-      }
-
+      auto iface = cast<MemoryEffectOpInterface>(&op);
       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
       iface.getEffects(effects);
     }
@@ -769,24 +849,20 @@ void transform::SequenceOp::getEffects(
   // Carry over all effects on the argument of the entry block as those on the
   // operand, this is the same value just remapped.
   for (Operation &op : *getBodyBlock()) {
-    auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
-    if (!iface) {
-      // TODO: fill all possible effects; or require ops to actually implement
-      // the memory effect interface always
-      assert(false);
-    }
+    auto iface = cast<MemoryEffectOpInterface>(&op);
 
-    SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
-    iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
-    for (const auto &effect : nestedEffects)
-      effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
+    remapEffects(iface, getBodyBlock()->getArgument(0), getRoot(), effects);
+    for (auto [source, target] : llvm::zip(
+             getBodyBlock()->getArguments().drop_front(), getExtraBindings())) {
+      remapEffects(iface, source, target, effects);
+    }
   }
 }
 
 OperandRange transform::SequenceOp::getSuccessorEntryOperands(
     std::optional<unsigned> index) {
   assert(index && *index == 0 && "unexpected region index");
-  if (getOperation()->getNumOperands() == 1)
+  if (getOperation()->getNumOperands() > 0)
     return getOperation()->getOperands();
   return OperandRange(getOperation()->operand_end(),
                       getOperation()->operand_end());
@@ -813,21 +889,51 @@ void transform::SequenceOp::getRegionInvocationBounds(
   bounds.emplace_back(1, 1);
 }
 
+template <typename FnTy>
+static void buildSequenceBody(OpBuilder &builder, OperationState &state,
+                              Type bbArgType, TypeRange extraBindingTypes,
+                              FnTy bodyBuilder) {
+  SmallVector<Type> types;
+  types.reserve(1 + extraBindingTypes.size());
+  types.push_back(bbArgType);
+  llvm::append_range(types, extraBindingTypes);
+
+  OpBuilder::InsertionGuard guard(builder);
+  Region *region = state.regions.back().get();
+  Block *bodyBlock = builder.createBlock(region, region->begin(),
+                                         extraBindingTypes, {state.location});
+
+  // Populate body.
+  builder.setInsertionPointToStart(bodyBlock);
+  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
+    bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+  } else {
+    bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
+                bodyBlock->getArguments().drop_front());
+  }
+}
+
 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                   TypeRange resultTypes,
                                   FailurePropagationMode failurePropagationMode,
                                   Value root,
                                   SequenceBodyBuilderFn bodyBuilder) {
-  build(builder, state, resultTypes, failurePropagationMode, root);
-  Region *region = state.regions.back().get();
+  build(builder, state, resultTypes, failurePropagationMode, root,
+        /*extraBindings=*/ValueRange());
   Type bbArgType = root.getType();
-  OpBuilder::InsertionGuard guard(builder);
-  Block *bodyBlock = builder.createBlock(
-      region, region->begin(), TypeRange{bbArgType}, {state.location});
+  buildSequenceBody(builder, state, bbArgType,
+                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
+}
 
-  // Populate body.
-  builder.setInsertionPointToStart(bodyBlock);
-  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
+                                  TypeRange resultTypes,
+                                  FailurePropagationMode failurePropagationMode,
+                                  Value root, ValueRange extraBindings,
+                                  SequenceBodyBuilderArgsFn bodyBuilder) {
+  build(builder, state, resultTypes, failurePropagationMode, root,
+        extraBindings);
+  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
+                    bodyBuilder);
 }
 
 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
@@ -835,15 +941,20 @@ void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                   FailurePropagationMode failurePropagationMode,
                                   Type bbArgType,
                                   SequenceBodyBuilderFn bodyBuilder) {
-  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value());
-  Region *region = state.regions.back().get();
-  OpBuilder::InsertionGuard guard(builder);
-  Block *bodyBlock = builder.createBlock(
-      region, region->begin(), TypeRange{bbArgType}, {state.location});
+  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
+        /*extraBindings=*/ValueRange());
+  buildSequenceBody(builder, state, bbArgType,
+                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
+}
 
-  // Populate body.
-  builder.setInsertionPointToStart(bodyBlock);
-  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
+                                  TypeRange resultTypes,
+                                  FailurePropagationMode failurePropagationMode,
+                                  Type bbArgType, TypeRange extraBindingTypes,
+                                  SequenceBodyBuilderArgsFn bodyBuilder) {
+  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
+        /*extraBindings=*/ValueRange());
+  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
 }
 
 //===----------------------------------------------------------------------===//
index 5cd57b0..593b885 100644 (file)
@@ -89,7 +89,9 @@ class ReplicateOp:
 class SequenceOp:
 
   def __init__(self, failure_propagation_mode, results: Sequence[Type],
-               target: Union[Operation, Value, Type]):
+               target: Union[Operation, Value, Type],
+               extra_bindings: Optional[Union[Sequence[Value], Sequence[Type],
+                                              Operation, OpView]] = None):
     root = _get_op_result_or_value(target) if isinstance(
         target, (Operation, Value)) else None
     root_type = root.type if not isinstance(target, Type) else target
@@ -98,10 +100,25 @@ class SequenceOp:
           IntegerType.get_signless(32), failure_propagation_mode._as_int())
     else:
       failure_propagation_mode = failure_propagation_mode
+
+    if extra_bindings is None:
+      extra_bindings = []
+    if isinstance(extra_bindings, (Operation, OpView)):
+      extra_bindings = _get_op_results_or_values(extra_bindings)
+
+    extra_binding_types = []
+    if len(extra_bindings) != 0:
+      if isinstance(extra_bindings[0], Type):
+        extra_binding_types = extra_bindings
+        extra_bindings = []
+      else:
+        extra_binding_types = [v.type for v in extra_bindings]
+
     super().__init__(results_=results,
                      failure_propagation_mode=failure_propagation_mode_attr,
-                     root=root)
-    self.regions[0].blocks.append(root_type)
+                     root=root,
+                     extra_bindings=extra_bindings)
+    self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
 
   @property
   def body(self) -> Block:
@@ -111,6 +128,10 @@ class SequenceOp:
   def bodyTarget(self) -> Value:
     return self.body.arguments[0]
 
+  @property
+  def bodyExtraArgs(self) -> BlockArgumentList:
+    return self.body.arguments[1:]
+
 
 class WithPDLPatternsOp:
 
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
new file mode 100644 (file)
index 0000000..447c6b4
--- /dev/null
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-ops=func.func bind-second-extra-to-ops=func.return})' \
+// RUN:             --split-input-file --verify-diagnostics
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+  transform.test_print_remark_at_operand %arg1, "first extra" : !transform.any_op
+  transform.test_print_remark_at_operand %arg2, "second extra" : !transform.any_op
+}
+
+// expected-remark @below {{first extra}}
+func.func @foo() {
+  // expected-remark @below {{second extra}}
+  return
+}
+
+// expected-remark @below {{first extra}}
+func.func @bar(%arg0: i1) {
+  cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  // expected-remark @below {{second extra}}
+  return
+^bb2:
+  // expected-remark @below {{second extra}}
+  return
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
+  // expected-error @above {{wrong kind of value provided for top-level parameter}}
+}
+
+func.func @foo() {
+  return
+}
+
+// -----
+
+// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op):
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+  transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
+  ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+    transform.test_print_remark_at_operand %arg4, "first extra" : !transform.any_op
+    transform.test_print_remark_at_operand %arg5, "second extra" : !transform.any_op
+  }
+}
+
+// expected-remark @below {{first extra}}
+func.func @foo() {
+  // expected-remark @below {{second extra}}
+  return
+}
+
+// expected-remark @below {{first extra}}
+func.func @bar(%arg0: i1) {
+  cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  // expected-remark @below {{second extra}}
+  return
+^bb2:
+  // expected-remark @below {{second extra}}
+  return
+}
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-params.mlir
new file mode 100644 (file)
index 0000000..f5d7f8f
--- /dev/null
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-params=1,2,3 bind-second-extra-to-params=42,45})' \
+// RUN:          --split-input-file --verify-diagnostics
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>):
+  // expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
+  transform.test_print_param %arg1 : !transform.param<i64>
+  // expected-remark @below {{42 : i64, 45 : i64}}
+  transform.test_print_param %arg2 : !transform.param<i64>
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
+  // expected-error @above {{wrong kind of value provided for top-level operation handle}}
+}
+
+// -----
+
+// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>, %arg3: !transform.param<i64>):
+}
index e957d7a..2fd0a37 100644 (file)
@@ -1,15 +1,22 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
-// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}}
+// expected-error @below {{expects the entry block to have at least one argument}}
 transform.sequence failures(propagate) {
 }
 
 // -----
 
+// expected-error @below {{expects the first entry block argument to be of type implementing TransformHandleTypeInterface}}
+transform.sequence failures(propagate) {
+^bb0(%rag0: i64):
+}
+
+// -----
+
 // expected-note @below {{nested in another possible top-level op}}
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
-  // expected-error @below {{expects the root operation to be provided for a nested op}}
+  // expected-error @below {{expects operands to be provided for a nested op}}
   transform.sequence failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
   }
@@ -17,6 +24,14 @@ transform.sequence failures(propagate) {
 
 // -----
 
+// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface or TransformParamTypeInterface}}
+// expected-note @below {{argument #1 does not}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: i64):
+}
+
+// -----
+
 // expected-error @below {{expected children ops to implement TransformOpInterface}}
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
@@ -46,10 +61,29 @@ transform.sequence failures(propagate) {
 
 // -----
 
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+  // expected-error @below {{expected types to be provided for all operands}}
+  transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op) failures(propagate) {
+  ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+  }
+}
+
+// -----
+
+%0 = "test.generate_something"() : () -> !transform.any_op
+// expected-error @below {{does not expect extra operands when used as top-level}}
+"transform.sequence"(%0) ({
+^bb0(%arg0: !transform.any_op):
+  "transform.yield"() : () -> ()
+}) {failure_propagation_mode = 1 : i32, operand_segment_sizes = array<i32: 0, 1>} : (!transform.any_op) -> ()
+
+// -----
+
 // expected-note @below {{nested in another possible top-level op}}
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):
-  // expected-error @below {{expects the root operation to be provided for a nested op}}
+  // expected-error @below {{expects operands to be provided for a nested op}}
   transform.sequence failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
   }
@@ -190,7 +224,7 @@ transform.sequence failures(propagate) {
 
 // -----
 
-// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}}
+// expected-error @below {{expects the entry block to have at least one argument}}
 transform.alternatives {
 ^bb0:
   transform.yield
index 0d27f92..73171a8 100644 (file)
@@ -50,6 +50,33 @@ transform.sequence failures(propagate) {
   }
 }
 
+// CHECK: transform.sequence failures(propagate)
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+  // CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
+  transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
+  ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+  }
+}
+
+// CHECK: transform.sequence failures(propagate)
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+  // CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
+  transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) {
+  ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+  }
+}
+
+// CHECK: transform.sequence failures(propagate)
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+  // CHECK: sequence %{{.*}}, %{{.*}}, %{{.*}} : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate)
+  transform.sequence %arg0, %arg1, %arg2 : (!transform.any_op, !transform.any_op, !transform.any_op) failures(propagate) {
+  ^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
+  }
+}
+
 // CHECK: transform.sequence
 // CHECK: foreach
 transform.sequence failures(propagate) {
index 1696cae..7d049eb 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 
@@ -39,12 +40,72 @@ public:
     return "apply transform dialect operations one by one";
   }
 
+  ArrayRef<transform::MappedValue>
+  findOperationsByName(Operation *root, StringRef name,
+                       SmallVectorImpl<transform::MappedValue> &storage) {
+    size_t start = storage.size();
+    root->walk([&](Operation *op) {
+      if (op->getName().getStringRef() == name) {
+        storage.push_back(op);
+      }
+    });
+    return ArrayRef(storage).drop_front(start);
+  }
+
+  ArrayRef<transform::MappedValue>
+  createParameterMapping(MLIRContext &context, ArrayRef<int> values,
+                         SmallVectorImpl<transform::MappedValue> &storage) {
+    size_t start = storage.size();
+    llvm::append_range(storage, llvm::map_range(values, [&](int v) {
+                         Builder b(&context);
+                         return transform::MappedValue(b.getI64IntegerAttr(v));
+                       }));
+    return ArrayRef(storage).drop_front(start);
+  }
+
   void runOnOperation() override {
+    if (!bindFirstExtraToOps.empty() && !bindFirstExtraToParams.empty()) {
+      emitError(UnknownLoc::get(&getContext()))
+          << "cannot bind the first extra top-level argument to both "
+             "operations and parameters";
+      return signalPassFailure();
+    }
+    if (!bindSecondExtraToOps.empty() && !bindSecondExtraToParams.empty()) {
+      emitError(UnknownLoc::get(&getContext()))
+          << "cannot bind the second extra top-level argument to both "
+             "operations and parameters";
+      return signalPassFailure();
+    }
+    if ((!bindSecondExtraToOps.empty() || !bindSecondExtraToParams.empty()) &&
+        bindFirstExtraToOps.empty() && bindFirstExtraToParams.empty()) {
+      emitError(UnknownLoc::get(&getContext()))
+          << "cannot bind the second extra top-level argument without binding "
+             "the first";
+      return signalPassFailure();
+    }
+
+    SmallVector<transform::MappedValue> extraMappingStorage;
+    SmallVector<ArrayRef<transform::MappedValue>> extraMapping;
+    if (!bindFirstExtraToOps.empty()) {
+      extraMapping.push_back(findOperationsByName(
+          getOperation(), bindFirstExtraToOps.getValue(), extraMappingStorage));
+    } else if (!bindFirstExtraToParams.empty()) {
+      extraMapping.push_back(createParameterMapping(
+          getContext(), bindFirstExtraToParams, extraMappingStorage));
+    }
+    if (!bindSecondExtraToOps.empty()) {
+      extraMapping.push_back(findOperationsByName(
+          getOperation(), bindSecondExtraToOps, extraMappingStorage));
+    } else if (!bindSecondExtraToParams.empty()) {
+      extraMapping.push_back(createParameterMapping(
+          getContext(), bindSecondExtraToParams, extraMappingStorage));
+    }
+
     ModuleOp module = getOperation();
     for (auto op :
          module.getBody()->getOps<transform::TransformOpInterface>()) {
       if (failed(transform::applyTransforms(
-              module, op,
+              module, op, extraMapping,
               transform::TransformOptions().enableExpensiveChecks(
                   enableExpensiveChecks))))
         return signalPassFailure();
@@ -55,6 +116,24 @@ public:
       *this, "enable-expensive-checks", llvm::cl::init(false),
       llvm::cl::desc("perform expensive checks to better report errors in the "
                      "transform IR")};
+
+  Option<std::string> bindFirstExtraToOps{
+      *this, "bind-first-extra-to-ops",
+      llvm::cl::desc("bind the first extra argument of the top-level op to "
+                     "payload operations of the given kind")};
+  ListOption<int> bindFirstExtraToParams{
+      *this, "bind-first-extra-to-params",
+      llvm::cl::desc("bind the first extra argument of the top-level op to "
+                     "the given integer parameters")};
+
+  Option<std::string> bindSecondExtraToOps{
+      *this, "bind-second-extra-to-ops",
+      llvm::cl::desc("bind the second extra argument of the top-level op to "
+                     "payload operations of the given kind")};
+  ListOption<int> bindSecondExtraToParams{
+      *this, "bind-second-extra-to-params",
+      llvm::cl::desc("bind the second extra argument of the top-level op to "
+                     "the given integer parameters")};
 };
 
 struct TestTransformDialectEraseSchedulePass
index c2ee6c1..ed6b68e 100644 (file)
@@ -70,6 +70,38 @@ def testNestedSequenceOp():
 
 
 @run
+def testSequenceOpWithExtras():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
+      [transform.AnyOpType.get(),
+       transform.OperationType.get("foo.bar")])
+  with InsertionPoint(sequence.body):
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testSequenceOpWithExtras
+  # CHECK: transform.sequence failures(propagate)
+  # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+
+
+@run
+def testNestedSequenceOpWithExtras():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
+      [transform.AnyOpType.get(),
+       transform.OperationType.get("foo.bar")])
+  with InsertionPoint(sequence.body):
+    nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
+                                  [], sequence.bodyTarget,
+                                  sequence.bodyExtraArgs)
+    with InsertionPoint(nested.body):
+      transform.YieldOp()
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+  # CHECK: transform.sequence failures(propagate)
+  # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+  # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+
+
+@run
 def testTransformPDLOps():
   withPdl = transform.WithPDLPatternsOp(pdl.OperationType.get())
   with InsertionPoint(withPdl.body):