[mlir] Transform dialect: add named sequences
authorAlex Zinenko <zinenko@google.com>
Mon, 20 Mar 2023 16:18:35 +0000 (16:18 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 21 Mar 2023 14:53:54 +0000 (14:53 +0000)
Named sequences introduce an additional abstraction and reuse capability
to the transform dialect. They can be though of as macros parameterized
with handles that can be invoked in places where a transform dialect
operation is expected. Such reuse was previously not possible in the
dialect and required dynamic construction of the transform IR from the
client language. Named sequences are intentionally restricted to
disallow recursion, as it could make the dialect accidentally
Turing-complete, which isn't desired at this point.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D146433

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/Dialect/Transform/test-interpreter.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 0c8c0b6..639a7c7 100644 (file)
@@ -23,8 +23,19 @@ def Transform_Dialect : Dialect {
     "::mlir::pdl_interp::PDLInterpDialect",
   ];
 
+  let hasOperationAttrVerify = 1;
 
   let extraClassDeclaration = [{
+      /// Name of the attribute attachable to the symbol table operation
+      /// containing named sequences. This is used to trigger verification.
+      constexpr const static llvm::StringLiteral
+          kWithNamedSequenceAttrName = "transform.with_named_sequence";
+
+      /// Names of the attribute attachable to an operation so it can be
+      /// identified as root by the default interpreter pass.
+      constexpr const static llvm::StringLiteral
+          kTargetTagAttrName = "transform.target_tag";
+
       /// Returns the named PDL constraint functions available in the dialect
       /// as a map from their name to the function.
       const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
index b2332c8..78a812e 100644 (file)
@@ -192,6 +192,12 @@ public:
   // class body to comply with visibility and full-declaration requirements.
   inline RegionScope make_region_scope(Region &region);
 
+  /// Creates a new region scope for the given isolated-from-above region.
+  /// Unlike the non-isolated counterpart, there is no nesting expectation.
+  // Implementation note: this method is inline but implemented outside of the
+  // class body to comply with visibility and full-declaration requirements
+  inline RegionScope make_isolated_region_scope(Region &region);
+
   /// A RAII object maintaining a "stack frame" for a transform IR region. When
   /// applying a transform IR operation that contains a region, the caller is
   /// expected to create a RegionScope before applying the ops contained in the
@@ -201,17 +207,23 @@ public:
   class RegionScope {
   public:
     /// Forgets the mapping from or to values defined in the associated
-    /// transform IR region.
+    /// transform IR region, and restores the mapping that existed before
+    /// entering this scope.
     ~RegionScope() {
       state.mappings.erase(region);
+      if (storedMappings.has_value())
+        state.mappings.swap(*storedMappings);
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
       state.regionStack.pop_back();
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
     }
 
   private:
+    /// Tag structure for differentiating the constructor for isolated regions.
+    struct Isolated {};
+
     /// Creates a new scope for mappings between values defined in the given
-    /// transform IR region and payload IR operations.
+    /// transform IR region and payload IR objects.
     RegionScope(TransformState &state, Region &region)
         : state(state), region(&region) {
       auto res = state.mappings.try_emplace(this->region);
@@ -225,13 +237,33 @@ public:
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
     }
 
+    /// Creates a new scope for mappings between values defined in the given
+    /// isolated-from-above transform IR region and payload IR objects.
+    RegionScope(TransformState &state, Region &region, Isolated)
+        : state(state), region(&region) {
+      // Store the previous mapping stack locally.
+      storedMappings = llvm::SmallDenseMap<Region *, Mappings>();
+      storedMappings->swap(state.mappings);
+      state.mappings.try_emplace(this->region);
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+      state.regionStack.push_back(this->region);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+    }
+
     /// Back-reference to the transform state.
     TransformState &state;
 
     /// The region this scope is associated with.
     Region *region;
 
+    /// Local copy of the mappings that existed before entering the current
+    /// region. Used only when the current region is isolated so we don't
+    /// accidentally look up the values defined outside the isolated region.
+    std::optional<llvm::SmallDenseMap<Region *, Mappings>> storedMappings =
+        std::nullopt;
+
     friend RegionScope TransformState::make_region_scope(Region &);
+    friend RegionScope TransformState::make_isolated_region_scope(Region &);
   };
   friend class RegionScope;
 
@@ -551,6 +583,13 @@ public:
   /// TransformValueHandleTypeInterface.
   void setValues(OpResult handle, ValueRange values);
 
+  /// Indicates that the result of the transform IR op at the given position
+  /// corresponds to the given range of mapped values. All mapped values are
+  /// expected to be compatible with the type of the result, e.g., if the result
+  /// is an operation handle, all mapped values are expected to be payload
+  /// operations.
+  void setMappedValues(OpResult handle, ArrayRef<MappedValue> values);
+
 private:
   /// Creates an instance of TransformResults that expects mappings for
   /// `numSegments` values, which may be associated with payload operations or
@@ -597,10 +636,21 @@ private:
   RaggedArray<Value> values;
 };
 
+/// Creates a RAII object the lifetime of which corresponds to the new mapping
+/// for transform IR values defined in the given region. Values defined in
+/// surrounding regions remain accessible.
 TransformState::RegionScope TransformState::make_region_scope(Region &region) {
   return RegionScope(*this, region);
 }
 
+/// Creates a RAII object the lifetime of which corresponds to the new mapping
+/// for transform IR values defined in the given isolated-from-above region.
+/// Values defined in surrounding regions cannot be accessed.
+TransformState::RegionScope
+TransformState::make_isolated_region_scope(Region &region) {
+  return RegionScope(*this, region, RegionScope::Isolated());
+}
+
 namespace detail {
 /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
 /// to either the list of operations associated with its operand or the root of
@@ -614,6 +664,12 @@ LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
 
 /// Verification hook for TransformOpInterface.
 LogicalResult verifyTransformOpInterface(Operation *op);
+
+/// Populates `mappings` with mapped values associated with the given transform
+/// IR values in the given `state`.
+void prepareValueMappings(
+    SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
+    ValueRange values, const transform::TransformState &state);
 } // namespace detail
 
 /// This trait is supposed to be attached to Transform dialect operations that
index 7eb9a01..2424b16 100644 (file)
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 
index 8865865..3ffc3f7 100644 (file)
@@ -9,10 +9,12 @@
 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
 
+include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/FunctionInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
@@ -266,6 +268,51 @@ def GetResultOp : TransformDialectOp<"get_result",
                        "functional-type(operands, results)";
 }
 
+def IncludeOp : TransformDialectOp<"include",
+    [CallOpInterface,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Includes a named transform sequence";
+  let description = [{
+    The application of this transform operation is equivalent to applying the
+    operations contained in the named transform sequence with operands being
+    remapped to block arguments. The behavior of the operation when a
+    transformation in the included named sequence produces a silenceable error
+    is controlled by the `failure_propagation_mode` attribute. When set to
+    `propagate`, the failure of any nested transformation in the sequence
+    implies immediate failure of the entire sequence with a silenceable error,
+    and no further transformation is attempted. When set to `suppress`,
+    silenceable errors in nested operations are ignored and further
+    transformations are applied. Beware that even silenceable errors may leave
+    the payload IR in a state unsuitable for further transformations. It is the
+    responsibility of the user to ensure the following transformations are
+    robust enough when errors are suppressed. Definite errors are propagated
+    immediately regardless of the mode. The objects associated with the results
+    of this operation are the same as those associated with the operands of the
+    `transform.yield` in the referenced named sequence.
+  }];
+
+  let arguments = (ins SymbolRefAttr:$target,
+                       FailurePropagationMode:$failure_propagation_mode,
+                       Variadic<Transform_AnyHandleOrParamType>:$operands);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
+  
+  let assemblyFormat =
+      "$target `failures` `(` $failure_propagation_mode `)`"
+      "`(` $operands `)` attr-dict `:` functional-type($operands, $results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::CallInterfaceCallable getCallableForCallee() {
+      return getTarget();
+    }
+
+    ::mlir::Operation::operand_range getArgOperands() {
+      return getOperands();
+    }
+  }];
+}
+
 def MergeHandlesOp : TransformDialectOp<"merge_handles",
     [DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -289,6 +336,67 @@ def MergeHandlesOp : TransformDialectOp<"merge_handles",
   let hasFolder = 1;
 }
 
+def NamedSequenceOp : TransformDialectOp<"named_sequence",
+    [CallableOpInterface,
+     FunctionOpInterface,
+     IsolatedFromAbove,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Named transform sequence that can be included elsewhere";
+  let description = [{
+    Defines a named (callable, function-like) sequence of other Transform
+    dialect operations that can be included using `transform.include` as part of
+    another Transform dialect construct. This sequence is not processed
+    immediately but rather dispatched to when the inclusion is processed. The
+    arguments and results can be used to communicate a subset of mapping into
+    the named sequence. The sequence must consist of a single block and end with
+    a `transform.yield` terminator. The operands of the terminator become the
+    results of the `transform.include`.
+
+    When dispatched to, the operations in the named sequence are executed one by
+    one, similarly to the regular unnamed sequence. The failure propagation mode
+    is specified on the `transform.include`. Different inclusions may use
+    different failure propagation modes. This transform operation always
+    succeeds by itself, but the inclusion may fail if any of the operations
+    fail.
+
+    Named sequences can only appear at the top-level of the Transform dialect
+    nesting structure. That is, they cannot be nested in other Transform dialect
+    operations. Furthermore, one of the ancestors must have the `SymbolTable`
+    trait and have the `transform.with_named_sequence` attribute attached.
+
+    Named sequences may include other named sequences via `transform.include`,
+    but recursion is *not* allowed.
+  }];
+  
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttrBase<"::mlir::FunctionType",
+                 "function type attribute">:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs);
+  let regions = (region SizedRegion<1>:$body);
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::llvm::ArrayRef<::mlir::Type> getArgumentTypes() {
+      return getFunctionType().getInputs();
+    }
+    ::llvm::ArrayRef<::mlir::Type> getResultTypes() {
+      return getFunctionType().getResults();
+    }
+
+    ::mlir::Region *getCallableRegion() {
+      return &getBody();
+    }
+    ::llvm::ArrayRef<::mlir::Type> getCallableResults() {
+      return getFunctionType().getResults();
+    }
+  }];
+}
+
 def SplitHandlesOp : TransformDialectOp<"split_handles",
     [FunctionalStyleTransformOpTrait,
      DeclareOpInterfaceMethods<TransformOpInterface>,
@@ -376,7 +484,6 @@ def PrintOp : TransformDialectOp<"print",
   let assemblyFormat = "$target attr-dict (`:` type($target)^)?";
 }
 
-
 def ReplicateOp : TransformDialectOp<"replicate",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -426,21 +533,21 @@ def SequenceOp : TransformDialectOp<"sequence",
   let description = [{
     The transformations indicated by the sequence are applied in order of their
     appearance. Each value produced by a transformation within the sequence
-    corresponds to an operation or a group of operations in the payload IR.
-    The behavior of the operation when a nested transformation produces a
-    silenceable error is controlled by the `failure_propagation_mode` attribute.
-    When set to `propagate`, the failure of any nested transformation in the
-    sequence implies immediate failure of the entire sequence with a silenceable
-    error, and no further transformation is attempted. When set to `suppress`,
+    corresponds to a group of operations or values in the payload IR, or to a
+    group of parameters, depending on the type of the value. The behavior of the
+    operation when a nested transformation produces a silenceable error is
+    controlled by the `failure_propagation_mode` attribute. When set to
+    `propagate`, the failure of any nested transformation in the sequence
+    implies immediate failure of the entire sequence with a silenceable error,
+    and no further transformation is attempted. When set to `suppress`,
     silenceable errors in nested operations are ignored and further
     transformations are applied. Beware that even silenceable errors may leave
-    the payload IR in a state unsuitable for further transformations. It is
-    the responsibility of the caller to ensure the following transformations
-    are robust enough when errors are suppressed. Definite errors reported by
-    nested transformations abort the sequence regardless of the propagation
-    mode. The set of modes may be extended in the future, e.g., to collect
-    silenceable errors and report them after attempting all transformations in
-    the sequence.
+    the payload IR in a state unsuitable for further transformations. It is the
+    responsibility of the caller to ensure the following transformations are
+    robust enough when errors are suppressed. Definite errors reported by nested
+    transformations abort the sequence regardless of the propagation mode. The
+    set of modes may be extended in the future, e.g., to collect silenceable
+    errors and report them after attempting all transformations in the sequence.
 
     The entry block of this operation has a single argument that maps to either
     the operand if provided or the top-level container operation of the payload
@@ -565,7 +672,8 @@ def YieldOp : TransformDialectOp<"yield",
   }];
 
   let arguments = (ins
-    Arg<Variadic<TransformHandleTypeInterface>, "Operation handles yielded back to the parent"
+    Arg<Variadic<Transform_AnyHandleOrParamType>,
+        "Transform values yielded back to the parent"
         >:$operands);
   let assemblyFormat = "operands attr-dict (`:` type($operands)^)?";
 
index 1f61ecd..99ff80e 100644 (file)
@@ -7,12 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Analysis/CallGraph.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/SCCIterator.h"
 
 using namespace mlir;
 
@@ -128,4 +130,53 @@ void transform::TransformDialect::reportDuplicateOpRegistration(
   llvm::report_fatal_error(StringRef(buffer));
 }
 
+LogicalResult transform::TransformDialect::verifyOperationAttribute(
+    Operation *op, NamedAttribute attribute) {
+  if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
+    if (!op->hasTrait<OpTrait::SymbolTable>()) {
+      return emitError(op->getLoc()) << attribute.getName()
+                                     << " attribute can only be attached to "
+                                        "operations with symbol tables";
+    }
+
+    const mlir::CallGraph callgraph(op);
+    for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
+      if (!scc.hasCycle())
+        continue;
+
+      // Need to check this here additionally because this verification may run
+      // before we check the nested operations.
+      if ((*scc->begin())->isExternal())
+        return op->emitOpError() << "contains a call to an external operation, "
+                                    "which is not allowed";
+
+      Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
+      InFlightDiagnostic diag = emitError(first->getLoc())
+                                << "recursion not allowed in named sequences";
+      for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
+        // Need to check this here additionally because this verification may
+        // run before we check the nested operations.
+        if ((*it)->isExternal()) {
+          return op->emitOpError() << "contains a call to an external "
+                                      "operation, which is not allowed";
+        }
+
+        Operation *current = (*it)->getCallableRegion()->getParentOp();
+        diag.attachNote(current->getLoc()) << "operation on recursion stack";
+      }
+      return diag;
+    }
+    return success();
+  }
+  if (attribute.getName().getValue() == kTargetTagAttrName) {
+    if (!attribute.getValue().isa<StringAttr>()) {
+      return op->emitError()
+             << attribute.getName() << " attribute must be a string";
+    }
+    return success();
+  }
+  return emitError(op->getLoc())
+         << "unknown attribute: " << attribute.getName();
+}
+
 #include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc"
index 4002e59..c9045c5 100644 (file)
@@ -104,50 +104,77 @@ LogicalResult transform::TransformState::getHandlesForPayloadValue(
   return success(found);
 }
 
-LogicalResult
-transform::TransformState::mapBlockArgument(BlockArgument argument,
-                                            ArrayRef<MappedValue> values) {
-  if (argument.getType().isa<TransformHandleTypeInterface>()) {
+/// Given a list of MappedValues, cast them to the value kind implied by the
+/// interface of the handle type, and dispatch to one of the callbacks.
+static DiagnosedSilenceableFailure dispatchMappedValues(
+    Value handle, ArrayRef<transform::MappedValue> values,
+    function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
+    function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
+    function_ref<LogicalResult(ValueRange)> valuesFn) {
+  if (handle.getType().isa<transform::TransformHandleTypeInterface>()) {
     SmallVector<Operation *> operations;
     operations.reserve(values.size());
-    for (MappedValue value : values) {
+    for (transform::MappedValue value : values) {
       if (auto *op = value.dyn_cast<Operation *>()) {
         operations.push_back(op);
         continue;
       }
-      return emitError(argument.getLoc())
+      return emitSilenceableFailure(handle.getLoc())
              << "wrong kind of value provided for top-level operation handle";
     }
-    return setPayloadOps(argument, operations);
+    if (failed(operationsFn(operations)))
+      return DiagnosedSilenceableFailure::definiteFailure();
+    return DiagnosedSilenceableFailure::success();
   }
 
-  if (argument.getType().isa<TransformValueHandleTypeInterface>()) {
+  if (handle.getType().isa<transform::TransformValueHandleTypeInterface>()) {
     SmallVector<Value> payloadValues;
     payloadValues.reserve(values.size());
-    for (MappedValue value : values) {
+    for (transform::MappedValue value : values) {
       if (auto v = value.dyn_cast<Value>()) {
         payloadValues.push_back(v);
         continue;
       }
-      return emitError(argument.getLoc())
+      return emitSilenceableFailure(handle.getLoc())
              << "wrong kind of value provided for the top-level value handle";
     }
-    return setPayloadValues(argument, payloadValues);
+    if (failed(valuesFn(payloadValues)))
+      return DiagnosedSilenceableFailure::definiteFailure();
+    return DiagnosedSilenceableFailure::success();
   }
 
-  assert(argument.getType().isa<TransformParamTypeInterface>() &&
+  assert(handle.getType().isa<transform::TransformParamTypeInterface>() &&
          "unsupported kind of block argument");
-  SmallVector<Param> parameters;
+  SmallVector<transform::Param> parameters;
   parameters.reserve(values.size());
-  for (MappedValue value : values) {
+  for (transform::MappedValue value : values) {
     if (auto attr = value.dyn_cast<Attribute>()) {
       parameters.push_back(attr);
       continue;
     }
-    return emitError(argument.getLoc())
+    return emitSilenceableFailure(handle.getLoc())
            << "wrong kind of value provided for top-level parameter";
   }
-  return setParams(argument, parameters);
+  if (failed(paramsFn(parameters)))
+    return DiagnosedSilenceableFailure::definiteFailure();
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult
+transform::TransformState::mapBlockArgument(BlockArgument argument,
+                                            ArrayRef<MappedValue> values) {
+  return dispatchMappedValues(
+             argument, values,
+             [&](ArrayRef<Operation *> operations) {
+               return setPayloadOps(argument, operations);
+             },
+             [&](ArrayRef<Param> params) {
+               return setParams(argument, params);
+             },
+             [&](ValueRange payloadValues) {
+               return setPayloadValues(argument, payloadValues);
+             })
+      .checkAndReport();
 }
 
 LogicalResult
@@ -887,6 +914,27 @@ void transform::TransformResults::setValues(OpResult handle,
   this->values.replace(position, values);
 }
 
+void transform::TransformResults::setMappedValues(
+    OpResult handle, ArrayRef<MappedValue> values) {
+  DiagnosedSilenceableFailure diag = dispatchMappedValues(
+      handle, values,
+      [&](ArrayRef<Operation *> operations) {
+        return set(handle, operations), success();
+      },
+      [&](ArrayRef<Param> params) {
+        return setParams(handle, params), success();
+      },
+      [&](ValueRange payloadValues) {
+        return setValues(handle, payloadValues), success();
+      });
+#ifndef NDEBUG
+  if (!diag.succeeded())
+    llvm::dbgs() << diag.getStatusString() << "\n";
+  assert(diag.succeeded() && "incorrect mapping");
+#endif // NDEBUG
+  (void)diag.silence();
+}
+
 ArrayRef<Operation *>
 transform::TransformResults::get(unsigned resultNumber) const {
   assert(resultNumber < operations.size() &&
@@ -1029,24 +1077,30 @@ void transform::detail::setApplyToOneResults(
 // Utilities for PossibleTopLevelTransformOpTrait.
 //===----------------------------------------------------------------------===//
 
+void transform::detail::prepareValueMappings(
+    SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
+    ValueRange values, const transform::TransformState &state) {
+  for (Value operand : values) {
+    SmallVector<MappedValue> &mapped = mappings.emplace_back();
+    if (operand.getType().isa<TransformHandleTypeInterface>()) {
+      llvm::append_range(mapped, state.getPayloadOps(operand));
+    } else if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
+      llvm::append_range(mapped, state.getPayloadValues(operand));
+    } else {
+      assert(operand.getType().isa<TransformParamTypeInterface>() &&
+             "unsupported kind of transform dialect value");
+      llvm::append_range(mapped, state.getParams(operand));
+    }
+  }
+}
+
 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
     TransformState &state, Operation *op, Region &region) {
   SmallVector<Operation *> targets;
   SmallVector<SmallVector<MappedValue>> extraMappings;
   if (op->getNumOperands() != 0) {
     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
-    for (Value operand : op->getOperands().drop_front()) {
-      SmallVector<MappedValue> &mapped = extraMappings.emplace_back();
-      if (operand.getType().isa<TransformHandleTypeInterface>()) {
-        llvm::append_range(mapped, state.getPayloadOps(operand));
-      } else if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
-        llvm::append_range(mapped, state.getPayloadValues(operand));
-      } else {
-        assert(operand.getType().isa<TransformParamTypeInterface>() &&
-               "unsupported kind of transform dialect value");
-        llvm::append_range(mapped, state.getParams(operand));
-      }
-    }
+    prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
   } else {
     if (state.getNumTopLevelMappings() !=
         region.front().getNumArguments() - 1) {
index cc4382a..6051007 100644 (file)
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
@@ -175,11 +176,19 @@ static void forwardEmptyOperands(Block *block, transform::TransformState &state,
 static void forwardTerminatorOperands(Block *block,
                                       transform::TransformState &state,
                                       transform::TransformResults &results) {
-  for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
-                                    block->getParentOp()->getOpResults())) {
-    Value terminatorOperand = std::get<0>(pair);
-    OpResult result = std::get<1>(pair);
-    results.set(result, state.getPayloadOps(terminatorOperand));
+  for (auto &&[terminatorOperand, result] :
+       llvm::zip(block->getTerminator()->getOperands(),
+                 block->getParentOp()->getOpResults())) {
+    if (result.getType().isa<transform::TransformHandleTypeInterface>()) {
+      results.set(result, state.getPayloadOps(terminatorOperand));
+    } else if (result.getType()
+                   .isa<transform::TransformValueHandleTypeInterface>()) {
+      results.setValues(result, state.getPayloadValues(terminatorOperand));
+    } else {
+      assert(result.getType().isa<transform::TransformParamTypeInterface>() &&
+             "unhandled transform type interface");
+      results.setParams(result, state.getParams(terminatorOperand));
+    }
   }
 }
 
@@ -525,6 +534,177 @@ transform::GetResultOp::apply(transform::TransformResults &results,
 }
 
 //===----------------------------------------------------------------------===//
+// IncludeOp
+//===----------------------------------------------------------------------===//
+
+/// Applies the transform ops contained in `block`. Maps `results` to the same
+/// values as the operands of the block terminator.
+static DiagnosedSilenceableFailure
+applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
+                   transform::TransformState &state,
+                   transform::TransformResults &results) {
+  // Apply the sequenced ops one by one.
+  for (Operation &transform : block.without_terminator()) {
+    DiagnosedSilenceableFailure result =
+        state.applyTransform(cast<transform::TransformOpInterface>(transform));
+    if (result.isDefiniteFailure())
+      return result;
+
+    if (result.isSilenceableFailure()) {
+      if (mode == transform::FailurePropagationMode::Propagate) {
+        // Propagate empty results in case of early exit.
+        forwardEmptyOperands(&block, state, results);
+        return result;
+      }
+      (void)result.silence();
+    }
+  }
+
+  // Forward the operation mapping for values yielded from the sequence to the
+  // values produced by the sequence op.
+  forwardTerminatorOperands(&block, state, results);
+  return DiagnosedSilenceableFailure::success();
+}
+
+DiagnosedSilenceableFailure
+transform::IncludeOp::apply(transform::TransformResults &results,
+                            transform::TransformState &state) {
+  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
+      getOperation(), getTarget());
+  assert(callee && "unverified reference to unknown symbol");
+
+  // Map operands to block arguments.
+  SmallVector<SmallVector<MappedValue>> mappings;
+  detail::prepareValueMappings(mappings, getOperands(), state);
+  auto scope = state.make_isolated_region_scope(callee.getBody());
+  for (auto &&[arg, map] :
+       llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
+    if (failed(state.mapBlockArgument(arg, map)))
+      return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  DiagnosedSilenceableFailure result = applySequenceBlock(
+      callee.getBody().front(), getFailurePropagationMode(), state, results);
+  mappings.clear();
+  detail::prepareValueMappings(
+      mappings, callee.getBody().front().getTerminator()->getOperands(), state);
+  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
+    results.setMappedValues(result, mapping);
+  return result;
+}
+
+/// Appends to `effects` the memory effect instances on `target` with the same
+/// resource and effect as the ones the operation `iface` having on `source`.
+static void
+remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
+             SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  SmallVector<MemoryEffects::EffectInstance> nestedEffects;
+  iface.getEffectsOnValue(source, nestedEffects);
+  for (const auto &effect : nestedEffects)
+    effects.emplace_back(effect.getEffect(), target, effect.getResource());
+}
+
+/// Appends to `effects` the same effects as the operations of `block` have on
+/// block arguments but associated with `operands.`
+static void
+remapArgumentEffects(Block &block, ValueRange operands,
+                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (Operation &op : block) {
+    auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
+    if (!iface)
+      continue;
+
+    for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
+      remapEffects(iface, source, target, effects);
+    }
+
+    SmallVector<MemoryEffects::EffectInstance> nestedEffects;
+    iface.getEffectsOnResource(transform::PayloadIRResource::get(),
+                               nestedEffects);
+    llvm::append_range(effects, nestedEffects);
+  }
+}
+
+static DiagnosedSilenceableFailure
+verifyNamedSequenceOp(transform::NamedSequenceOp op);
+
+void transform::IncludeOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Bail if the callee is unknown. This may run as part of the verification
+  // process before we verified the validity of the callee or of this op.
+  auto target =
+      getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
+  if (!target)
+    return;
+  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
+      getOperation(), getTarget());
+  if (!callee)
+    return;
+  DiagnosedSilenceableFailure earlyVerifierResult =
+      verifyNamedSequenceOp(callee);
+  if (!earlyVerifierResult.succeeded()) {
+    (void)earlyVerifierResult.silence();
+    return;
+  }
+
+  // Carry over effects from the callee.
+  remapArgumentEffects(callee.getBody().front(), getOperands(), effects);
+
+  // Proper effects.
+  onlyReadsHandle(getOperands(), effects);
+  producesHandle(getResults(), effects);
+}
+
+template <typename... Tys>
+static bool implementSameInterface(Type t1, Type t2) {
+  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
+}
+
+LogicalResult
+transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Access through indirection and do additional checking because this may be
+  // running before the main op verifier.
+  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
+  if (!targetAttr)
+    return emitOpError() << "expects a 'target' symbol reference attribute";
+
+  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
+      *this, targetAttr);
+  if (!target)
+    return emitOpError() << "does not reference a named transform sequence";
+
+  FunctionType fnType = target.getFunctionType();
+  if (fnType.getNumInputs() != getNumOperands())
+    return emitError("incorrect number of operands for callee");
+
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+    if (getOperand(i).getType() != fnType.getInput(i)) {
+      return emitOpError("operand type mismatch: expected operand type ")
+             << fnType.getInput(i) << ", but provided "
+             << getOperand(i).getType() << " for operand number " << i;
+    }
+  }
+
+  if (fnType.getNumResults() != getNumResults())
+    return emitError("incorrect number of results for callee");
+
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+    Type resultType = getResult(i).getType();
+    Type funcType = fnType.getResult(i);
+    if (!implementSameInterface<TransformHandleTypeInterface,
+                                TransformValueHandleTypeInterface,
+                                TransformParamTypeInterface>(resultType,
+                                                             funcType)) {
+      return emitOpError() << "type of result #" << i
+                           << " must implement the same transform dialect "
+                              "interface as the corresponding callee result";
+    }
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // MergeHandlesOp
 //===----------------------------------------------------------------------===//
 
@@ -568,6 +748,105 @@ OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
 }
 
 //===----------------------------------------------------------------------===//
+// NamedSequenceOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::NamedSequenceOp::apply(transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  // Nothing to do here.
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::NamedSequenceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
+
+ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name),
+      [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
+         function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(inputs, results); },
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+}
+
+void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
+  function_interface_impl::printFunctionOp(
+      printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
+      getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
+      getResAttrsAttrName());
+}
+
+/// Verification of a NamedSequenceOp. This does not report the error
+/// immediately, so it can be used to check for op's well-formedness before the
+/// verifier runs, e.g., during trait verification.
+static DiagnosedSilenceableFailure
+verifyNamedSequenceOp(transform::NamedSequenceOp op) {
+  if (op.isExternal())
+    return emitSilenceableFailure(op) << "cannot be empty";
+
+  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
+    if (!parent->getAttr(
+            transform::TransformDialect::kWithNamedSequenceAttrName)) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableFailure(op)
+          << "expects the parent symbol table to have the '"
+          << transform::TransformDialect::kWithNamedSequenceAttrName
+          << "' attribute";
+      diag.attachNote(parent->getLoc()) << "symbol table operation";
+      return diag;
+    }
+  }
+
+  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableFailure(op)
+        << "cannot be defined inside another transform op";
+    diag.attachNote(parent.getLoc()) << "ancestor transform op";
+    return diag;
+  }
+
+  if (op.getBody().front().empty())
+    return emitSilenceableFailure(op) << "expected a non-empty body block";
+
+  Operation *terminator = &op.getBody().front().back();
+  if (!isa<transform::YieldOp>(terminator)) {
+    DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
+                                       << "expected '"
+                                       << transform::YieldOp::getOperationName()
+                                       << "' as terminator";
+    diag.attachNote(terminator->getLoc()) << "terminator";
+    return diag;
+  }
+
+  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
+    return emitSilenceableFailure(terminator)
+           << "expected terminator to have as many operands as the parent op "
+              "has results";
+  }
+  for (auto [i, operandType, resultType] :
+       llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
+                       terminator->getOperands().getType(),
+                       op.getFunctionType().getResults())) {
+    if (operandType == resultType)
+      continue;
+    return emitSilenceableFailure(terminator)
+           << "the type of the terminator operand #" << i
+           << " must match the type of the corresponding parent op result ("
+           << operandType << " vs " << resultType << ")";
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::NamedSequenceOp::verify() {
+  // Actual verification happens in a separate function for reusability.
+  return verifyNamedSequenceOp(*this).checkAndReport();
+}
+
+//===----------------------------------------------------------------------===//
 // SplitHandlesOp
 //===----------------------------------------------------------------------===//
 
@@ -692,27 +971,8 @@ transform::SequenceOp::apply(transform::TransformResults &results,
   if (failed(mapBlockArguments(state)))
     return DiagnosedSilenceableFailure::definiteFailure();
 
-  // Apply the sequenced ops one by one.
-  for (Operation &transform : getBodyBlock()->without_terminator()) {
-    DiagnosedSilenceableFailure result =
-        state.applyTransform(cast<TransformOpInterface>(transform));
-    if (result.isDefiniteFailure())
-      return result;
-
-    if (result.isSilenceableFailure()) {
-      if (getFailurePropagationMode() == FailurePropagationMode::Propagate) {
-        // Propagate empty results in case of early exit.
-        forwardEmptyOperands(getBodyBlock(), state, results);
-        return result;
-      }
-      (void)result.silence();
-    }
-  }
-
-  // Forward the operation mapping for values yielded from the sequence to the
-  // values produced by the sequence op.
-  forwardTerminatorOperands(getBodyBlock(), state, results);
-  return DiagnosedSilenceableFailure::success();
+  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
+                            results);
 }
 
 static ParseResult parseSequenceOpOperands(
@@ -871,22 +1131,6 @@ LogicalResult transform::SequenceOp::verify() {
   return success();
 }
 
-/// Appends to `effects` the memory effect instances on `target` with the same
-/// resource and effect as the ones the operation `iface` having on `source`.
-static void
-remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
-             SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  SmallVector<MemoryEffects::EffectInstance> nestedEffects;
-  iface.getEffectsOnValue(source, nestedEffects);
-  for (const auto &effect : nestedEffects)
-    effects.emplace_back(effect.getEffect(), target, effect.getResource());
-}
-
-namespace {
-template <typename T>
-using has_get_extra_bindings = decltype(std::declval<T &>().getExtraBindings());
-} // namespace
-
 /// Populate `effects` with transform dialect memory effects for the potential
 /// top-level operation. Such operations have recursive effects from nested
 /// operations. When they have an operand, we can additionally remap effects on
@@ -911,26 +1155,8 @@ static void getPotentialTopLevelEffects(
 
   // Carry over all effects on arguments of the entry block as those on the
   // operands, this is the same value just remapped.
-  for (Operation &op : *operation.getBodyBlock()) {
-    auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
-    if (!iface)
-      continue;
-
-    remapEffects(iface, operation.getBodyBlock()->getArgument(0),
-                 operation.getRoot(), effects);
-    if constexpr (llvm::is_detected<has_get_extra_bindings, OpTy>::value) {
-      for (auto [source, target] :
-           llvm::zip(operation.getBodyBlock()->getArguments().drop_front(),
-                     operation.getExtraBindings())) {
-        remapEffects(iface, source, target, effects);
-      }
-    }
-
-    SmallVector<MemoryEffects::EffectInstance> nestedEffects;
-    iface.getEffectsOnResource(transform::PayloadIRResource::get(),
-                               nestedEffects);
-    llvm::append_range(effects, nestedEffects);
-  }
+  remapArgumentEffects(*operation.getBodyBlock(), operation->getOperands(),
+                       effects);
 }
 
 void transform::SequenceOp::getEffects(
index 3d6ee21..40624e6 100644 (file)
@@ -83,6 +83,9 @@ static Operation *findTopLevelTransform(Operation *root,
   ::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
   WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
       [&](::mlir::transform::TransformOpInterface transformOp) {
+        if (!transformOp
+                 ->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
+          return WalkResult::skip();
         if (!topLevelTransform) {
           topLevelTransform = transformOp;
           return WalkResult::skip();
index 4abaa23..ee03d9e 100644 (file)
@@ -284,3 +284,184 @@ transform.sequence failures(suppress) {
   // expected-note @below {{no 'allocate' effect specified for result #0}}
   transform.test_required_memory_effects %arg0 {has_operand_effect, modifies_payload} : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+// expected-error @below {{attribute can only be attached to operations with symbol tables}}
+"test.unknown_container"() { transform.with_named_sequence } : () -> ()
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-error @below {{failed to verify constraint: region with 1 blocks}}
+  "transform.named_sequence"() ({}) { sym_name = "external_named_sequence", function_type = () -> () } : () -> ()
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    transform.include @external_named_sequence failures(propagate) () : () -> ()
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-error @below {{recursion not allowed in named sequences}}
+  transform.named_sequence @self_recursion() -> () {
+    transform.include @self_recursion failures(suppress) () : () -> ()
+  }
+}
+
+// -----
+
+module @mutual_recursion attributes { transform.with_named_sequence } {
+  // expected-note @below {{operation on recursion stack}}  
+  transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+    transform.include @bar failures(suppress) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+
+  // expected-error @below {{recursion not allowed in named sequences}}
+  transform.named_sequence @bar(%arg0: !transform.any_op) -> () {
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-error @below {{unknown attribute: "transform.unknown_container"}}
+module @unknown_attribute attributes { transform.unknown_container } {}
+
+// -----
+
+module {
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{op does not reference a named transform sequence}}
+    transform.include @non_existent failures(propagate) () : () -> ()
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{requires attribute 'target'}}
+    "transform.include"() {failure_propagation_mode = 0} : () -> ()
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+    transform.yield
+  }
+
+  transform.sequence failures(suppress) {
+  ^bb0(%arg1: !transform.any_op):
+    // expected-error @below {{incorrect number of operands for callee}}
+    transform.include @foo failures(suppress) () : () -> ()
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+    transform.yield
+  }
+
+  transform.sequence failures(suppress) {
+  ^bb0(%arg1: !transform.op<"builtin.module">):
+    // expected-error @below {{operand type mismatch: expected operand type '!transform.any_op', but provided '!transform.op<"builtin.module">' for operand number 0}}
+    transform.include @foo failures(suppress) (%arg1) : (!transform.op<"builtin.module">) -> ()
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op) {
+    transform.yield %arg0 : !transform.any_op
+  }
+
+  transform.sequence failures(suppress) {
+  ^bb0(%arg1: !transform.any_op):
+    // expected-error @below {{incorrect number of results for callee}}
+    transform.include @foo failures(suppress) (%arg1) : (!transform.any_op) -> ()
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op) {
+    transform.yield %arg0 : !transform.any_op
+  }
+
+  transform.sequence failures(suppress) {
+  ^bb0(%arg1: !transform.any_op):
+    // expected-error @below {{type of result #0 must implement the same transform dialect interface as the corresponding callee result}}
+    transform.include @foo failures(suppress) (%arg1) : (!transform.any_op) -> (!transform.any_value)
+  }
+}
+
+// -----
+
+// expected-note @below {{symbol table operation}}
+module {
+  // expected-error @below {{expects the parent symbol table to have the 'transform.with_named_sequence' attribute}}
+  transform.named_sequence @parent_has_no_attributes() {
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence} {
+  // expected-note @below {{ancestor transform op}}
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{cannot be defined inside another transform op}}
+    transform.named_sequence @nested() {
+      transform.yield
+    }
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence} {
+  func.func private @foo()
+
+  // expected-error @below {{expected 'transform.yield' as terminator}}
+  transform.named_sequence @nested() {
+    // expected-note @below {{terminator}}
+    func.call @foo() : () -> ()
+  }
+}
+
+
+// -----
+
+module attributes { transform.with_named_sequence} {
+  func.func private @foo()
+
+  transform.named_sequence @nested(%arg0: !transform.any_op) {
+    // expected-error @below {{expected terminator to have as many operands as the parent op has results}}
+    transform.yield %arg0 : !transform.any_op
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence} {
+  func.func private @foo()
+
+  transform.named_sequence @nested(%arg0: !transform.any_op) -> !transform.op<"builtin.module"> {
+    // expected-error @below {{the type of the terminator operand #0 must match the type of the corresponding parent op result}}
+    transform.yield %arg0 : !transform.any_op
+  }
+}
index 3a7f420..6b2b0dd 100644 (file)
@@ -1255,3 +1255,82 @@ transform.sequence failures(propagate) {
   %op = transform.get_defining_op %bbarg : (!transform.any_value) -> !transform.any_op
   transform.test_print_remark_at_operand %op, "matched" : !transform.any_op
 }
+
+// -----
+
+module @named_inclusion attributes { transform.with_named_sequence } {
+
+  transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+    // expected-remark @below {{applying transformation "a"}}
+    transform.test_transform_op "a"
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
+
+// -----
+
+module @named_inclusion_in_named attributes { transform.with_named_sequence } {
+
+  transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+    // expected-remark @below {{applying transformation "a"}}
+    transform.test_transform_op "a"
+    transform.yield
+  }
+
+  transform.named_sequence @bar(%arg0: !transform.any_op) -> () {
+    // expected-remark @below {{applying transformation "b"}}
+    transform.test_transform_op "b"
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    transform.include @bar failures(suppress) (%arg0) : (!transform.any_op) -> ()
+  }
+}
+
+// -----
+
+// expected-remark @below {{operation}}
+module @named_operands attributes { transform.with_named_sequence } {
+
+  transform.named_sequence @foo(%arg0: !transform.any_op, %arg1: !transform.any_value) -> () {
+    transform.test_print_remark_at_operand %arg0, "operation" : !transform.any_op
+    transform.test_print_remark_at_operand_value %arg1, "value" : !transform.any_value
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  // expected-remark @below {{value}}
+  // expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}}
+  ^bb0(%arg0: !transform.any_op):
+    %0 = transform.test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value
+    include @foo failures(propagate) (%arg0, %0) : (!transform.any_op, !transform.any_value) -> ()
+  }
+}
+
+// -----
+
+// expected-remark @below {{operation}}
+module @named_return attributes { transform.with_named_sequence } {
+
+  // expected-remark @below {{value}}
+  // expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}}
+  transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op, !transform.any_value) {
+    %0 = transform.test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value
+    transform.yield %arg0, %0 : !transform.any_op, !transform.any_value
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    %0:2 = include @foo failures(propagate) (%arg0) : (!transform.any_op) -> (!transform.any_op, !transform.any_value)
+    transform.test_print_remark_at_operand %0#0, "operation" : !transform.any_op
+    transform.test_print_remark_at_operand_value %0#1, "value" : !transform.any_value
+  }
+}
index 3bef7dd..99a8653 100644 (file)
@@ -9316,7 +9316,10 @@ gentbl_cc_library(
     ],
     tblgen = ":mlir-tblgen",
     td_file = "include/mlir/Dialect/Transform/IR/TransformOps.td",
-    deps = [":TransformDialectTdFiles"],
+    deps = [
+        ":CallInterfacesTdFiles",
+        ":TransformDialectTdFiles"
+    ],
 )
 
 gentbl_cc_library(
@@ -9342,6 +9345,7 @@ cc_library(
     srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]),
     hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]),
     deps = [
+        ":CallInterfaces",
         ":ControlFlowInterfaces",
         ":IR",
         ":PDLDialect",