[mlir] add support for transform dialect value handles
authorAlex Zinenko <zinenko@google.com>
Fri, 3 Feb 2023 14:00:33 +0000 (14:00 +0000)
committerAlex Zinenko <zinenko@google.com>
Thu, 9 Feb 2023 12:11:24 +0000 (12:11 +0000)
Introduce support for the third kind of values in the transform dialect:
value handles. Similarly to operation handles, value handles are
pointing to a set of values in the payload IR. This enables
transformation to be targeted at specific values, such as individual
results of a multi-result payload operation without indirecting through
the producing op or block arguments that previously could not be easily
addressed. This is expected to support a broad class of memory-oriented
transformations such as selective bufferization, buffer assignment, and
memory transfer management.

Value handles are functionally similar to operation handles and require
similar implementation logic. The most important change concerns the
handle invalidation mechanism where operation and value handles can
affect each other.

This patch includes two cleanups that make it easier to introduce value
handles:

  - `RaggedArray` structure that encapsulates the SmallVector of
    ArrayRef backed by flat SmallVector logic, frequently used in the
    transform interfaces implementation;

  - rewrite the tests that associated payload handles with an integer
    value `reinterpret_cast`ed as a pointer, which were a frequent
    source of confusion and crashes when adding more debugging
    facilities that can inspect the payload.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D143385

22 files changed:
mlir/docs/Dialects/Transform.md
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h [new file with mode: 0644]
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
mlir/test/Dialect/Linalg/transform-op-match.mlir
mlir/test/Dialect/Transform/check-use-after-free.mlir
mlir/test/Dialect/Transform/expensive-checks.mlir
mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir
mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir [new file with mode: 0644]
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/Dialect/Transform/test-dialect-injection.mlir
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/Dialect/Transform/transform-state-extension.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp

index e33604c..123d661 100644 (file)
@@ -25,20 +25,19 @@ of the IR using a different portion of the IR. It refers to the IR being
 transformed as payload IR, and to the IR guiding the transformation as
 transform IR.
 
-The main use case for this dialect is orchestrating fine-grain
-transformations on individual operations or sets thereof. For example, it
-may involve finding loop-like operations with specific properties (e.g.,
-large size) in the payload IR, applying loop tiling to those and only those
-operations, and then applying loop unrolling to the inner loops produced
-by the previous transformations. As such, it is not intended as a
-replacement for the pass infrastructure, nor for the pattern rewriting
-infrastructure. In the most common case, the transform IR will be processed
-and applied to the payload IR by a pass. Transformations expressed by the
-transform dialect may be implemented using the pattern infrastructure or any
-other relevant MLIR component.
+The main use case for this dialect is orchestrating fine-grain transformations
+on individual IR objects (operations or values) or sets thereof. For example, it
+may involve finding loop-like operations with specific properties (e.g., large
+size) in the payload IR, applying loop tiling to those and only those
+operations, and then applying loop unrolling to the inner loops produced by the
+previous transformations. As such, it is not intended as a replacement for the
+pass infrastructure, nor for the pattern rewriting infrastructure. In the most
+common case, the transform IR will be processed and applied to the payload IR by
+a pass. Transformations expressed by the transform dialect may be implemented
+using the pattern infrastructure or any other relevant MLIR component.
 
 The following IR gives a rough idea of what the operations in this dialect
-may look like:
+may look like without using actually existing operations:
 
 ```mlir
 %0 = transform.loop.find { size > 42 } : !transform.interface<tileable>
@@ -46,57 +45,70 @@ may look like:
 %2:2 = transform.loop.tile %0 tile_sizes(1, 4, %1)
       : (!transform.interface<tileable>)
      -> (!transform.op<loop>, !transform.op<loop>)
+%3 = transform.get_op_result [0] %2#0 : !transform.any_value
+transform.assign_to_fast_memory %3
 transform.loop.unroll %1#1 : !transform.op<loop>
 ```
 
-The values used in the Transform dialect may correspond to either:
+The values used in the Transform dialect may correspond to:
 
   * sets of operations in the payload IR;
 
+  * sets of values in the payload IR;
+
   * sets of parameters (attributes) known at the execution time of the
     transform dialect.
 
-The former kind of values is also referred to as *handles*. In the example
-above, `%0` corresponds to the set of loops found in the payload IR that
-satisfy the condition, and `%2` correspond to groups of outer and inner
-loops, respectively, produced by the tiling transformation, whereas `%1`
-corresponds to a list of tile sizes selected for each of the operations
-that `%0` corresponds to.
+The former two kinds of values are also referred to as operation and value
+*handles*, respectively. In the example above, `%0` corresponds to the set of
+loops found in the payload IR that satisfy the condition, and `%2` correspond to
+groups of outer and inner loops, respectively, produced by the tiling
+transformation. `%3` corresponds to a set of values that are produced by the
+outer loops after tiling. `%1` corresponds to a list of tile sizes selected for
+each of the operations that `%0` corresponds to.
 
-A transform handle such as `%0` may be associated with multiple payload
+An operation handle such as `%0` may be associated with multiple payload
 operations. This is conceptually a set of operations and no assumptions should
 be made about the order of ops unless specified otherwise by the operation.
-Operations may take as operands and produce an arbitrary combination of values
-representing handles and parameters. Most Transform IR ops support operand
-values that are mapped to multiple operations. They usually apply the respective
-transformation for every mapped op ("batched execution"). Deviations from this
-convention are described in the documentation of Transform IR ops.
-
-The transform IR values have transform IR types, which implement either
-[TransformHandleTypeInterface](Transform.md#transformhandletypeinterface-transformhandletypeinterface)
-or
-[TransformParamTypeInterface](Transform.md##transformparamtypeinterface-transformparamtypeinterface).
-The former interface verifiers properties of payload IR operations associated
-with the value that are known to the transform dialect, for example, all
-associated payload operations implement a "TileableOp" interface, or have a
-specific "loop" kind. Similarly, the latter interface verifies properties of
-attributes associated with the parameter value. These properties are used to
-statically indicate pre- and post-conditions of a transformation connected to a
-Transform dialect operation. The conditions are verified when attributes or
-payload IR operations are first associated with a transform handle. By
-convention, Transform dialect operations are expected to indicate narrow
-preconditions for their operands by enforcing operand type constraints in the
-their definitions and verifiers. On the contrary, operations are expected to
-have few constraints on their results. Specific instances of a transform
-operation can then be created with a more restricted result type than the
-constraint in the operation (e.g., the "find" operation only constrains the
-result type to be a transform IR type while its concrete instance can have a
-type with stricter constraints such as implementing the "tilable" interface).
-The verification will then happen at transform execution time. This approach
-allows one to capture payload IR operation properties in the transform IR
-without resorting to excessive use of type casts or coupling dialect extensions
-between themselves. It is a trade-off between verbosity/complexity and static
-hardening, which can be revised in the future.
+Similarly, a value handle such as `%3` may be associated with a set of payload
+IR values. Transform dialect operations may take as operands and produce an
+arbitrary combination of values representing handles and parameters. Most
+Transform IR ops support operand values that are mapped to multiple payload
+objects. They usually apply the respective transformation for every mapped
+object ("batched execution"). Deviations from this convention are described in
+the documentation of Transform IR ops.
+
+The transform IR values have transform IR types, which should implement exactly one of:
+
+  * [TransformHandleTypeInterface](Transform.md#transformhandletypeinterface-transformhandletypeinterface),
+
+  * [TransformValueHandleTypeInterface](Transform.md#transformvaluehandletypeinterface-transformvaluehandletypeinterface),
+
+  * [TransformParamTypeInterface](Transform.md##transformparamtypeinterface-transformparamtypeinterface).
+
+The goal of these type interfaces, beyond providing a common base for accepted
+types, is to verify the properties of the associated objects. For example, a
+handle type interface implementation may check whether all associated payload IR
+operations implement the "TileableOp" interface or have a specific "loop" kind.
+Similarly, a value handle type interface implementation may check if the
+associated payload IR values are block arguments or have a specific type, or a
+parameter type interface may check whether the associated attributes contain
+non-negative integer values. These properties are used to statically indicate
+ pre- and post-conditions of a transformation connected to a Transform dialect
+operation. The conditions are verified when payload objects operations are first
+associated with a transform handle. By convention, Transform dialect operations
+are expected to indicate narrow preconditions for their operands by enforcing
+operand type constraints in the their definitions and verifiers. On the
+contrary, operations are expected to have few constraints on their results.
+Specific instances of a transform operation can then be created with a more
+restricted result type than the constraint in the operation (e.g., the "find"
+operation only constrains the result type to be a transform IR type while its
+concrete instance can have a type with stricter constraints such as implementing
+the "tilable" interface). The verification will then happen at transform
+execution time. This approach allows one to capture payload IR operation
+properties in the transform IR without resorting to excessive use of type casts
+or coupling dialect extensions between themselves. It is a trade-off between
+verbosity/complexity and static hardening, which can be revised in the future.
 
 Overall, Transform IR ops are expected to be contained in a single top-level
 op. Such top-level ops specify how to apply the transformations described
@@ -111,7 +123,7 @@ programmatically triggered by calling:
 ```c++
 LogicalResult transform::applyTransforms(
     Operation *payloadRoot,
-    ArrayRef<ArrayRef<PointerUnion<Operation *, Attribute>> extraMappings,
+    const RaggedArray<transform::MappedValue> &extraMappings,
     TransformOpInterface transform,
     const TransformOptions &options);
 ```
@@ -163,7 +175,7 @@ Similarly to operations, additional types can be injected into the dialect using
 the same extension mechanism. The types must:
 
   * Implement exactly one of `TransformHandleTypeInterface`,
-    `TransformParamTypeInterface`.
+    `TransformValueHandleTypeInterface`, `TransformParamTypeInterface`.
 
 ## Side Effects
 
@@ -255,18 +267,57 @@ operation lists.
 
 ## Handle Invalidation
 
-The execution model of the transform dialect allows a payload IR operation
-to be associated with _multiple_ handles as well as nested payload IR
-operations to be associated with different handles. A transform IR operation
-that consumes a handle automatically _invalidates_ all the other handles
-associated with the same payload IR operations, or with any of their
-descendants, as the consumed handle. Note that the _entire_ handle is
-invalidated, even if some of the payload IR operations associated with it
-or their ancestors were not associated with the consumed handle. Any use of
-the invalidated handle results in undefined behavior since the payload IR
-operations associated with it are likely to have been mutated or erased. The
-mere fact of the handle being invalidated does _not_ trigger undefined
-behavior, only its appearance as an operand does.
+The execution model of the transform dialect allows a payload IR operation to be
+associated with _multiple_ handles as well as nested payload IR operations to be
+associated with different handles. Similarly, a payload IR value may be
+associated with multiple transform IR value handles. When a transform IR
+operation consumes a handle, it usually indicates that the corresponding payload
+IR object was destroyed and should no longer be referenced. Transform IR handles
+that _may_ be pointing to an erased payload IR object are _invalidated_. The
+mere presence of an invalidated handle in the transform IR is not a problem, but
+_using_ it results in undefined behavior. Invalidated handles can be thought of
+as dangling pointers. Note that the _entire_ handle is invalidated, even if some
+of the payload IR objects associated with it remain live.
+
+The following handle invalidation rules apply.
+
+  * When an operation handle is consumed, are invalidated:
+
+    - operation handles associated with one of the payload operations that the
+      consumed handle is associated with;
+
+    - operation handles associated with one of the operations _nested_ in the
+      payload operations described above;
+
+    - value handles associated with any result of any operation described above;
+    
+    - value handles associated with any argument of a block contained in a
+      region attached to any operation described above.
+
+  * When a value handle is consumed, are invalidated:
+
+    - operation handles associated with payload operations that produce as
+      result any value associated with the consumed handle (when the associated
+      is an operation result);
+
+    - operation handles associated with payload operations _nested_ in the
+      payload operations described above;
+
+    - operation handles associated with payload operations (recursively)
+      _contained_ in the block that defines as argument any value associated
+      with the consumed handle (when the associated value is a block argument);
+      note that the adjacent blocks are not affected;
+
+    - value handles associated with any result of any operation described above,
+      including all results of the operation defining as result the value
+      associated with the consumed handle;
+    
+    - value handles associated with any argument of a block contained in a
+      region attached to any operation described above.
+
+More intuitively, consuming a handle invalidates any handle that may be pointing
+to an object defined or contained in the payload IR subtree rooted at the
+closest operation or block.
 
 The Transform dialect infrastructure has the capability of checking whether
 the transform IR op operand is invalidated before applying the
index bae5d45..dc9612c 100644 (file)
@@ -11,8 +11,8 @@
 
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h"
+#include "mlir/Dialect/Transform/Utils/RaggedArray.h"
 #include "mlir/IR/OpDefinition.h"
-
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -45,7 +45,7 @@ private:
 };
 
 using Param = Attribute;
-using MappedValue = llvm::PointerUnion<Operation *, Param>;
+using MappedValue = llvm::PointerUnion<Operation *, Param, Value>;
 
 /// Entry point to the Transform dialect infrastructure. Applies the
 /// transformation specified by `transform` to payload IR contained in
@@ -55,7 +55,7 @@ using MappedValue = llvm::PointerUnion<Operation *, Param>;
 /// This function internally keeps track of the transformation state.
 LogicalResult
 applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
-                ArrayRef<ArrayRef<MappedValue>> extraMapping = {},
+                const RaggedArray<MappedValue> &extraMapping = {},
                 const TransformOptions &options = TransformOptions());
 
 /// The state maintained across applications of various ops implementing the
@@ -107,16 +107,22 @@ private:
   /// parameters.
   using ParamMapping = DenseMap<Value, SmallVector<Param>>;
 
+  /// Mapping between a Value in the transform IR and the corrsponding list of
+  /// values in the payload IR. Also works for reverse mappings.
+  using ValueMapping = DenseMap<Value, SmallVector<Value>>;
+
   /// The bidirectional mappings between transform IR values and payload IR
   /// operations, and the mapping between transform IR values and parameters.
   struct Mappings {
     TransformOpMapping direct;
     TransformOpReverseMapping reverse;
     ParamMapping params;
+    ValueMapping values;
+    ValueMapping reverseValues;
   };
 
   friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
-                                       ArrayRef<ArrayRef<MappedValue>>,
+                                       const RaggedArray<MappedValue> &,
                                        const TransformOptions &);
 
 public:
@@ -140,11 +146,21 @@ public:
   /// corresponds to.
   ArrayRef<Attribute> getParams(Value value) const;
 
+  /// Returns the list of payload IR values that the given transform IR value
+  /// corresponds to.
+  ArrayRef<Value> getPayloadValues(Value handleValue) const;
+
   /// Populates `handles` with all handles pointing to the given Payload IR op.
   /// Returns success if such handles exist, failure otherwise.
   LogicalResult getHandlesForPayloadOp(Operation *op,
                                        SmallVectorImpl<Value> &handles) const;
 
+  /// Populates `handles` with all handles pointing to the given payload IR
+  /// value. Returns success if such handles exist, failure otherwise.
+  LogicalResult
+  getHandlesForPayloadValue(Value payloadValue,
+                            SmallVectorImpl<Value> &handles) const;
+
   /// Applies the transformation specified by the given transform op and updates
   /// the state accordingly.
   DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform);
@@ -319,10 +335,10 @@ private:
   /// which may or may not contain the region with transform ops. Additional
   /// options can be provided through the trailing configuration object.
   TransformState(Region *region, Operation *payloadRoot,
-                 ArrayRef<ArrayRef<MappedValue>> extraMappings = {},
+                 const RaggedArray<MappedValue> &extraMappings = {},
                  const TransformOptions &options = TransformOptions());
 
-  /// Returns the mappings frame for the reigon in which the value is defined.
+  /// Returns the mappings frame for the region in which the value is defined.
   const Mappings &getMapping(Value value) const {
     return const_cast<TransformState *>(this)->getMapping(value);
   }
@@ -344,10 +360,6 @@ private:
     return it->second;
   }
 
-  /// Removes the mapping between the given payload IR operation and the given
-  /// transform IR value.
-  void dropReverseMapping(Mappings &mappings, Operation *op, Value value);
-
   /// Sets the payload IR ops associated with the given transform IR value
   /// (handle). A payload op may be associated multiple handles as long as
   /// at most one of them gets consumed by further transformations.
@@ -367,40 +379,111 @@ private:
   /// by side effects. Practically, a transformation consuming a handle means
   /// that the associated payload operation may no longer exist.
   ///
+  /// Similarly, operation handles may be invalidate and should not be used
+  /// after a transform that consumed a value handle pointing to a payload value
+  /// defined by the operation as either block argument or op result. For
+  /// example, in the following sequence, the last transform operation rewrites
+  /// the callee to not return a specified result:
+  ///
+  ///   %0 = transform.find_call "myfunc"
+  ///   %1 = transform.find_results_of_calling "myfunc"
+  ///   transform.drop_call_result_from_signature %1[0]
+  ///
+  /// which requires the call operations to be recreated. Therefore, the handle
+  /// %0 becomes associated with a dangling pointer and should not be used.
+  ///
   /// Returns failure if the payload does not satisfy the conditions associated
   /// with the type of the handle value. The value is expected to have a type
   /// implementing TransformHandleTypeInterface.
   LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
 
+  /// Sets the payload IR values association with the given transform IR value
+  /// (handle). A payload value may be associated with multiple handles as long
+  /// as at most one of them is consumed by further transformations. For
+  /// example, a hypothetical "get results of calls to function with the given
+  /// name" transform may be performed twice in a row producing handles pointing
+  /// to the same values:
+  ///
+  ///   %0 = transform.find_results_of_calling "myfunc"
+  ///   %1 = transform.find_results_of_calling "myfunc"
+  ///
+  /// which is valid by itself. However, calling a hypothetical "erase value
+  /// producer" transform on both handles:
+  ///
+  ///   transform.erase_value_produce %0
+  ///   transform.erase_value_produce %1
+  ///
+  /// is invalid provided the transformation "consumes" the handle as expressed
+  /// by side effects (which themselves reflect the semantics of the transform
+  /// erasing the producer and making the handle dangling). Practically, a
+  /// transformation consuming a handle means the associated payload value may
+  /// no longer exist.
+  ///
+  /// Similarly, value handles are invalidated and should not be used after a
+  /// transform that consumed an operation handle pointing to the payload IR
+  /// operation defining the values associated the value handle, as either block
+  /// arguments or op results, or any ancestor operation. For example,
+  ///
+  ///   %0 = transform.find_call "myfunc"
+  ///   %1 = transform.find_results_of_calling "myfunc"
+  ///   transform.rewrite_and_rename %0 { new_name = "func" }
+  ///
+  /// makes %1 unusable after the last transformation if it consumes %0. When an
+  /// operation handle is consumed, it usually indicates that the operation was
+  /// destroyed or heavily modified, meaning that the values it defines may no
+  /// longer exist.
+  ///
+  /// Returns failure if the payload values do not satisfy the conditions
+  /// associated with the type of the handle value. The value is expected to
+  /// have a type implementing TransformValueHandleTypeInterface.
+  LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
+
   /// Sets the parameters associated with the given transform IR value. Returns
   /// failure if the parameters do not satisfy the conditions associated with
   /// the type of the value. The value is expected to have a type implementing
   /// TransformParamTypeInterface.
   LogicalResult setParams(Value value, ArrayRef<Param> params);
 
-  /// Forgets the payload IR ops associated with the given transform IR value.
-  void removePayloadOps(Value value);
+  /// Forgets the payload IR ops associated with the given transform IR value,
+  /// as well as any association between value handles and the results of said
+  /// payload IR op.
+  void forgetMapping(Value opHandle, ValueRange origOpFlatResults);
+
+  void forgetValueMapping(Value valueHandle,
+                          ArrayRef<Operation *> payloadOperations);
 
   /// Updates the payload IR ops associated with the given transform IR value.
   /// The callback function is called once per associated operation and is
   /// expected to return the modified operation or nullptr. In the latter case,
   /// the corresponding operation is no longer associated with the transform IR
-  /// value.
+  /// value. Value handles associated with the results of the operation are
+  /// also updated to be associated with the results of the new operation. For
+  /// this reason, the new operation must have the same number of results.
   ///
   /// Returns failure if the payload does not satisfy the conditions associated
   /// with the type of the handle value.
-  LogicalResult
-  updatePayloadOps(Value value,
-                   function_ref<Operation *(Operation *)> callback);
+  LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
 
   /// If the operand is a handle consumed by the operation, i.e. has the "free"
   /// memory effect associated with it, identifies other handles that are
   /// pointing to payload IR operations nested in the operations pointed to by
   /// the consumed handle. Marks all such handles as invalidated to trigger
-  /// errors if they are used.
-  void recordHandleInvalidation(OpOperand &handle);
-  void recordHandleInvalidationOne(OpOperand &handle, Operation *payloadOp,
-                                   Value otherHandle);
+  /// errors if they are used. If `throughValue` is passed, record the fact that
+  /// an op handle was invalidated because a value handle associated with
+  /// results of the payload op or its block arguments was invalidated.
+  void recordOpHandleInvalidation(OpOperand &consumingHandle,
+                                  ArrayRef<Operation *> potentialAncestors,
+                                  Value throughValue = nullptr);
+  void recordOpHandleInvalidationOne(OpOperand &handle,
+                                     ArrayRef<Operation *> potentialAncestors,
+                                     Operation *payloadOp, Value otherHandle,
+                                     Value throughValue = nullptr);
+
+  void recordValueHandleInvalidationByOpHandleOne(
+      OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
+      Value payloadValue, Value valueHandle);
+
+  void recordValueHandleInvalidation(OpOperand &valueHandle);
 
   /// Checks that the operation does not use invalidated handles as operands.
   /// Reports errors and returns failure if it does. Otherwise, invalidates the
@@ -421,14 +504,10 @@ private:
   /// The top-level operation that contains all payload IR, typically a module.
   Operation *topLevel;
 
-  /// Storage for extra mapped values (payload operations or parameters) to be
+  /// Extra mapped values (payload operations, values or parameters) to be
   /// associated with additional entry block arguments of the top-level
-  /// transform operation. Each entry in `topLevelMappedValues` is a reference
-  /// to a contiguous block in `topLevelMappedValueStorage`.
-  // TODO: turn this into a proper named data structure, there are several more
-  // below.
-  SmallVector<ArrayRef<MappedValue>> topLevelMappedValues;
-  SmallVector<MappedValue> topLevelMappedValueStorage;
+  /// transform operation.
+  RaggedArray<MappedValue> topLevelMappedValues;
 
   /// Additional options controlling the transformation state behavior.
   TransformOptions options;
@@ -455,16 +534,23 @@ class TransformResults {
 public:
   /// Indicates that the result of the transform IR op at the given position
   /// corresponds to the given list of payload IR ops. Each result must be set
-  /// by the transformation exactly once. The value must have a type
-  /// implementing TransformHandleTypeInterface.
+  /// by the transformation exactly once in case of transformation succeeding.
+  /// The value must have a type implementing TransformHandleTypeInterface.
   void set(OpResult value, ArrayRef<Operation *> ops);
 
   /// Indicates that the result of the transform IR op at the given position
   /// corresponds to the given list of parameters. Each result must be set by
-  /// the transformation exactly once. The value must have a type implementing
-  /// TransformParamTypeInterface.
+  /// the transformation exactly once in case of transformation succeeding. The
+  /// value must have a type implementing TransformParamTypeInterface.
   void setParams(OpResult value, ArrayRef<TransformState::Param> params);
 
+  /// Indicates that the result of the transform IR op at the given position
+  /// corresponds to the given range of payload IR values. Each result must be
+  /// set by the transformation exactly once in case of transformation
+  /// succeeding. The value must have a type implementing
+  /// TransformValueHandleTypeInterface.
+  void setValues(OpResult handle, ValueRange values);
+
 private:
   /// Creates an instance of TransformResults that expects mappings for
   /// `numSegments` values, which may be associated with payload operations or
@@ -481,34 +567,34 @@ private:
   /// be associated with parameters.
   ArrayRef<TransformState::Param> getParams(unsigned resultNumber) const;
 
+  /// Gets the list of payload IR values associated with the result identified
+  /// by its number in the list of operation results. The result must have been
+  /// set to be associated with payload IR values.
+  ArrayRef<Value> getValues(unsigned resultNumber) const;
+
   /// Returns `true` if the result identified by its number in the list of
-  /// operation results is associated with a list of parameters, `false` if it
-  /// is associated with the list of payload IR operations.
+  /// operation results is associated with a list of parameters, `false`
+  /// otherwise.
   bool isParam(unsigned resultNumber) const;
 
   /// Returns `true` if the result identified by its number in the list of
+  /// operation results is associated with a list of payload IR value, `false`
+  /// otherwise.
+  bool isValue(unsigned resultNumber) const;
+
+  /// Returns `true` if the result identified by its number in the list of
   /// operation results is associated with something.
   bool isSet(unsigned resultNumber) const;
 
-  /// Storage for pointers to payload IR ops that are associated with results of
-  /// a transform IR op. `segments` contains as many entries as the transform IR
-  /// op has results, even if some of them are not associated with payload IR
-  /// operations. Each entry is a reference to a contiguous segment in the
-  /// `operations` list that contains the pointers to operations. This allows
-  /// for operations to be stored contiguously without nested vectors and for
-  /// different segments to be set in any order.
-  SmallVector<ArrayRef<Operation *>, 2> segments;
-  SmallVector<Operation *> operations;
-
-  /// Storage for parameters that are associated with results of the transform
-  /// IR op. `paramSegments` contains as many entries as the transform IR op has
-  /// results, even if some of them are not associated with parameters. Each
-  /// entry is a reference to a contiguous segment in the `params` list that
-  /// contains the actual parameters. This allows for parameters to be stored
-  /// contiguously without nested vectors and for different segments to be set
-  /// in any order.
-  SmallVector<ArrayRef<TransformState::Param>, 2> paramSegments;
-  SmallVector<TransformState::Param> params;
+  /// Pointers to payload IR ops that are associated with results of a transform
+  /// IR op.
+  RaggedArray<Operation *> operations;
+
+  /// Parameters that are associated with results of the transform IR op.
+  RaggedArray<Param> params;
+
+  /// Payload IR values that are associated with results of a transform IR op.
+  RaggedArray<Value> values;
 };
 
 TransformState::RegionScope TransformState::make_region_scope(Region &region) {
@@ -625,14 +711,14 @@ public:
 /// Side effect resource corresponding to the mapping between Transform IR
 /// values and Payload IR operations. An Allocate effect from this resource
 /// means creating a new mapping entry, it is always accompanied by a Write
-/// effet. A Read effect from this resource means accessing the mapping. A Free
+/// effect. A Read effect from this resource means accessing the mapping. A Free
 /// effect on this resource indicates the removal of the mapping entry,
 /// typically after a transformation that modifies the Payload IR operations
 /// associated with one of the Transform IR operation's operands. It is always
 /// accompanied by a Read effect. Read-after-Free and double-Free are not
 /// allowed (they would be problematic with "regular" memory effects too) as
 /// they indicate an attempt to access Payload IR operations that have been
-/// modified, potentially erased, by the previous tranfsormations.
+/// modified, potentially erased, by the previous transformations.
 // TODO: consider custom effects if these are not enabling generic passes such
 // as CSE/DCE to work.
 struct TransformMappingResource
@@ -769,7 +855,7 @@ namespace transform {
 
 /// A single result of applying a transform op with `ApplyEachOpTrait` to a
 /// single payload operation.
-using ApplyToEachResult = llvm::PointerUnion<Operation *, Attribute>;
+using ApplyToEachResult = MappedValue;
 
 /// A list of results of applying a transform op with `ApplyEachOpTrait` to a
 /// single payload operation, co-indexed with the results of the transform op.
@@ -793,6 +879,9 @@ public:
       if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
                                           Operation *>) {
         results.push_back(static_cast<Operation *>(element));
+      } else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
+                                                 Value>) {
+        results.push_back(element.template get<Value>());
       } else {
         results.push_back(static_cast<Attribute>(element));
       }
@@ -800,8 +889,12 @@ public:
   }
 
   /// Appends an element to the list.
+  // Using ApplyToEachResult that can be implicitly constructed from a Value but
+  // not from a concrete Op that is implicitly convertible to a Value to avoid
+  // ambiguity.
   void push_back(Operation *op) { results.push_back(op); }
   void push_back(Attribute attr) { results.push_back(attr); }
+  void push_back(ApplyToEachResult r) { results.push_back(r); }
 
   /// Reserves space for `size` elements in the list.
   void reserve(unsigned size) { results.reserve(size); }
index 22c0c94..f443961 100644 (file)
@@ -137,10 +137,10 @@ def TransformHandleTypeInterface
     : TransformTypeInterfaceBase<"TransformHandleTypeInterface",
                                  "::mlir::Operation *"> {
   let description = [{
-    Types that can be used for the Transform dialect handle values. Such types
-    define the properties of Payload IR operations associated with the handle.
-    A user of such a handle can assume that these properties have been verified
-    for any Payload IR operation associated with it.
+    Types that can be used for the Transform dialect operation handle values.
+    Such types define the properties of Payload IR operations associated with
+    the handle. A user of such a handle can assume that these properties have
+    been verified for any Payload IR operation associated with it.
   }];
 }
 
@@ -155,9 +155,21 @@ def TransformParamTypeInterface
   }];
 }
 
+def TransformValueHandleTypeInterface
+    : TransformTypeInterfaceBase<"TransformValueHandleTypeInterface",
+                                 "::mlir::Value"> {
+  let description = [{
+    Types that can be used for the Transform dialect handle values pointing to
+    Payload IR values. Such types define the properties of Payload IR values
+    associated with the handle. Users of such a handle can assume that these
+    properties have been verified for any Payload IR value associated with it.
+  }];
+}
+
 def Transform_AnyHandleOrParamType
   : Type<Or<[TransformParamTypeInterface.predicate,
-             TransformHandleTypeInterface.predicate]>,
+             TransformHandleTypeInterface.predicate,
+             TransformValueHandleTypeInterface.predicate]>,
          "any transform handle or parameter">;
 
 def FunctionalStyleTransformOpTrait
index ebaf576..9eece0f 100644 (file)
@@ -52,6 +52,15 @@ def Transform_ParamType : TypeDef<Transform_Dialect, "Param",
   let genVerifyDecl = 1;
 }
 
+def Transform_AnyValue : TypeDef<Transform_Dialect, "AnyValue",
+    [DeclareTypeInterfaceMethods<TransformValueHandleTypeInterface>]> {
+  let description = [{
+    Transform IR value that can be associated with a list of Payload IR values.
+  }];
+  let mnemonic = "any_value";
+  let assemblyFormat = "";
+}
+
 class Transform_ConcreteOpType<string opname>
   : Type<And<[Transform_OperationType.predicate,
               CPred<"$_self.cast<::mlir::transform::OperationType>()"
index 1b1ad6a..0a60b4c 100644 (file)
@@ -40,7 +40,7 @@ interpreterBaseInitializeImpl(MLIRContext *context, StringRef transformFileName,
 LogicalResult interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
-    ArrayRef<ArrayRef<MappedValue>> extraMappings,
+    const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
     const Pass::Option<std::string> &debugPayloadRootTag,
diff --git a/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h b/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h
new file mode 100644 (file)
index 0000000..c3c38d6
--- /dev/null
@@ -0,0 +1,92 @@
+//===- RaggedArray.h - 2D array with different inner lengths ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+/// A 2D array where each row may have different length. Elements of each row
+/// are stored contiguously, but rows don't have a fixed order in the storage.
+template <typename T>
+class RaggedArray {
+public:
+  /// Returns the number of rows in the 2D array.
+  size_t size() const { return slices.size(); }
+
+  /// Returns true if the are no rows in the 2D array. Note that an array with a
+  /// non-zero number of empty rows is *NOT* empty.
+  bool empty() const { return slices.empty(); }
+
+  /// Accesses `pos`-th row.
+  ArrayRef<T> operator[](size_t pos) const { return at(pos); }
+  ArrayRef<T> at(size_t pos) const { return slices[pos]; }
+  MutableArrayRef<T> operator[](size_t pos) { return at(pos); }
+  MutableArrayRef<T> at(size_t pos) { return slices[pos]; }
+
+  /// Iterator over rows.
+  auto begin() { return slices.begin(); }
+  auto begin() const { return slices.begin(); }
+  auto end() { return slices.end(); }
+  auto end() const { return slices.end(); }
+
+  /// Reserve space to store `size` rows with `nestedSize` elements each.
+  void reserve(size_t size, size_t nestedSize = 0) {
+    slices.reserve(size);
+    storage.reserve(size * nestedSize);
+  }
+
+  /// Appends the given range of elements as a new row to the 2D array. May
+  /// invalidate the end iterator.
+  template <typename Range>
+  void push_back(Range &&elements) {
+    slices.push_back(appendToStorage(std::forward<Range>(elements)));
+  }
+
+  /// Replaces the `pos`-th row in the 2D array with the given range of
+  /// elements. Invalidates iterators and references to `pos`-th and all
+  /// succeeding rows.
+  template <typename Range>
+  void replace(size_t pos, Range &&elements) {
+    auto from = slices[pos].data();
+    if (from != nullptr) {
+      auto to = std::next(from, slices[pos].size());
+      auto newFrom = storage.erase(from, to);
+      // Update the array refs after the underlying storage was shifted.
+      for (size_t i = pos + 1, e = size(); i < e; ++i) {
+        slices[i] = MutableArrayRef<T>(newFrom, slices[i].size());
+        std::advance(newFrom, slices[i].size());
+      }
+    }
+    slices[pos] = appendToStorage(std::forward<Range>(elements));
+  }
+
+  /// Appends `num` empty rows to the array.
+  void appendEmptyRows(size_t num) { slices.resize(slices.size() + num); }
+
+private:
+  /// Appends the given elements to the storage and returns an ArrayRef pointing
+  /// to them in the storage.
+  template <typename Range>
+  MutableArrayRef<T> appendToStorage(Range &&elements) {
+    size_t start = storage.size();
+    llvm::append_range(storage, std::forward<Range>(elements));
+    return MutableArrayRef<T>(storage).drop_front(start);
+  }
+
+  /// Outer elements of the ragged array. Each entry is a reference to a
+  /// contiguous segment in the `storage` list that contains the actual
+  /// elements. This allows for elements to be stored contiguously without
+  /// nested vectors and for different segments to be set or replaced in any
+  /// order.
+  SmallVector<MutableArrayRef<T>> slices;
+
+  /// Dense storage for ragged array elements.
+  SmallVector<T> storage;
+};
+} // namespace mlir
index fadc9ce..1f61ecd 100644 (file)
@@ -38,12 +38,14 @@ void transform::detail::checkImplementsTransformOpInterface(
 void transform::detail::checkImplementsTransformHandleTypeInterface(
     TypeID typeID, MLIRContext *context) {
   const auto &abstractType = AbstractType::lookup(typeID, context);
-  assert(
-      (abstractType.hasInterface(
-           TransformHandleTypeInterface::getInterfaceID()) ||
-       abstractType.hasInterface(
-           TransformParamTypeInterface::getInterfaceID())) &&
-      "expected Transform dialect type to implement one of the two interfaces");
+  assert((abstractType.hasInterface(
+              TransformHandleTypeInterface::getInterfaceID()) ||
+          abstractType.hasInterface(
+              TransformParamTypeInterface::getInterfaceID()) ||
+          abstractType.hasInterface(
+              TransformValueHandleTypeInterface::getInterfaceID())) &&
+         "expected Transform dialect type to implement one of the three "
+         "interfaces");
 }
 #endif // NDEBUG
 
index e14fca2..6b0f59d 100644 (file)
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
@@ -29,17 +30,12 @@ constexpr const Value transform::TransformState::kTopLevelValue;
 
 transform::TransformState::TransformState(
     Region *region, Operation *payloadRoot,
-    ArrayRef<ArrayRef<MappedValue>> extraMappings,
+    const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options)
     : topLevel(payloadRoot), options(options) {
   topLevelMappedValues.reserve(extraMappings.size());
-  for (ArrayRef<MappedValue> mapping : extraMappings) {
-    size_t start = topLevelMappedValueStorage.size();
-    llvm::append_range(topLevelMappedValueStorage, mapping);
-    topLevelMappedValues.push_back(
-        ArrayRef<MappedValue>(topLevelMappedValueStorage)
-            .slice(start, mapping.size()));
-  }
+  for (ArrayRef<MappedValue> mapping : extraMappings)
+    topLevelMappedValues.push_back(mapping);
 
   auto result = mappings.try_emplace(region);
   assert(result.second && "the region scope is already present");
@@ -55,16 +51,26 @@ ArrayRef<Operation *>
 transform::TransformState::getPayloadOps(Value value) const {
   const TransformOpMapping &operationMapping = getMapping(value).direct;
   auto iter = operationMapping.find(value);
-  assert(iter != operationMapping.end() &&
-         "cannot find mapping for payload handle (param handle provided?)");
+  assert(
+      iter != operationMapping.end() &&
+      "cannot find mapping for payload handle (param/value handle provided?)");
   return iter->getSecond();
 }
 
 ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
   const ParamMapping &mapping = getMapping(value).params;
   auto iter = mapping.find(value);
-  assert(iter != mapping.end() &&
-         "cannot find mapping for param handle (payload handle provided?)");
+  assert(iter != mapping.end() && "cannot find mapping for param handle "
+                                  "(operation/value handle provided?)");
+  return iter->getSecond();
+}
+
+ArrayRef<Value>
+transform::TransformState::getPayloadValues(Value handleValue) const {
+  const ValueMapping &mapping = getMapping(handleValue).values;
+  auto iter = mapping.find(handleValue);
+  assert(iter != mapping.end() && "cannot find mapping for value handle "
+                                  "(param/operation handle provided?)");
   return iter->getSecond();
 }
 
@@ -82,6 +88,20 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
   return success(found);
 }
 
+LogicalResult transform::TransformState::getHandlesForPayloadValue(
+    Value payloadValue, SmallVectorImpl<Value> &handles) const {
+  bool found = false;
+  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+    auto iterator = mapping.reverseValues.find(payloadValue);
+    if (iterator != mapping.reverseValues.end()) {
+      llvm::append_range(handles, iterator->getSecond());
+      found = true;
+    }
+  }
+
+  return success(found);
+}
+
 LogicalResult
 transform::TransformState::mapBlockArgument(BlockArgument argument,
                                             ArrayRef<MappedValue> values) {
@@ -99,6 +119,20 @@ transform::TransformState::mapBlockArgument(BlockArgument argument,
     return setPayloadOps(argument, operations);
   }
 
+  if (argument.getType().isa<TransformValueHandleTypeInterface>()) {
+    SmallVector<Value> payloadValues;
+    payloadValues.reserve(values.size());
+    for (MappedValue value : values) {
+      if (auto v = value.dyn_cast<Value>()) {
+        payloadValues.push_back(v);
+        continue;
+      }
+      return emitError(argument.getLoc())
+             << "wrong kind of value provided for the top-level value handle";
+    }
+    return setPayloadValues(argument, payloadValues);
+  }
+
   assert(argument.getType().isa<TransformParamTypeInterface>() &&
          "unsupported kind of block argument");
   SmallVector<Param> parameters;
@@ -119,8 +153,8 @@ transform::TransformState::setPayloadOps(Value value,
                                          ArrayRef<Operation *> targets) {
   assert(value != kTopLevelValue &&
          "attempting to reset the transformation root");
-  assert(!value.getType().isa<TransformParamTypeInterface>() &&
-         "cannot associate payload ops with a value of parameter type");
+  assert(value.getType().isa<TransformHandleTypeInterface>() &&
+         "wrong handle type");
 
   for (Operation *target : targets) {
     if (target)
@@ -150,6 +184,41 @@ transform::TransformState::setPayloadOps(Value value,
   return success();
 }
 
+LogicalResult
+transform::TransformState::setPayloadValues(Value handle,
+                                            ValueRange payloadValues) {
+  assert(handle != nullptr && "attempting to set params for a null value");
+  assert(handle.getType().isa<TransformValueHandleTypeInterface>() &&
+         "wrong handle type");
+
+  for (Value payload : payloadValues) {
+    if (payload)
+      continue;
+    return emitError(handle.getLoc()) << "attempting to assign a null payload "
+                                         "value to this transform handle";
+  }
+
+  auto iface = handle.getType().cast<TransformValueHandleTypeInterface>();
+  SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
+  DiagnosedSilenceableFailure result =
+      iface.checkPayload(handle.getLoc(), payloadValueVector);
+  if (failed(result.checkAndReport()))
+    return failure();
+
+  Mappings &mappings = getMapping(handle);
+  bool inserted =
+      mappings.values.insert({handle, std::move(payloadValueVector)}).second;
+  assert(
+      inserted &&
+      "value handle is already associated with another list of payload values");
+  (void)inserted;
+
+  for (Value payload : payloadValues)
+    mappings.reverseValues[payload].push_back(handle);
+
+  return success();
+}
+
 LogicalResult transform::TransformState::setParams(Value value,
                                                    ArrayRef<Param> params) {
   assert(value != nullptr && "attempting to set params for a null value");
@@ -177,54 +246,146 @@ LogicalResult transform::TransformState::setParams(Value value,
   return success();
 }
 
-void transform::TransformState::dropReverseMapping(Mappings &mappings,
-                                                   Operation *op, Value value) {
-  auto it = mappings.reverse.find(op);
-  if (it == mappings.reverse.end())
+template <typename Mapping, typename Key, typename Mapped>
+void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
+  auto it = mapping.find(key);
+  if (it == mapping.end())
     return;
 
-  llvm::erase_value(it->getSecond(), value);
+  llvm::erase_value(it->getSecond(), mapped);
   if (it->getSecond().empty())
-    mappings.reverse.erase(it);
+    mapping.erase(it);
 }
 
-void transform::TransformState::removePayloadOps(Value value) {
-  Mappings &mappings = getMapping(value);
-  for (Operation *op : mappings.direct[value])
-    dropReverseMapping(mappings, op, value);
-  mappings.direct.erase(value);
+void transform::TransformState::forgetMapping(Value opHandle,
+                                              ValueRange origOpFlatResults) {
+  Mappings &mappings = getMapping(opHandle);
+  for (Operation *op : mappings.direct[opHandle])
+    dropMappingEntry(mappings.reverse, op, opHandle);
+  mappings.direct.erase(opHandle);
+
+  for (Value opResult : origOpFlatResults) {
+    SmallVector<Value> resultHandles;
+    (void)getHandlesForPayloadValue(opResult, resultHandles);
+    for (Value resultHandle : resultHandles) {
+      Mappings &localMappings = getMapping(resultHandle);
+      dropMappingEntry(localMappings.values, resultHandle, opResult);
+      dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
+    }
+  }
 }
 
-LogicalResult transform::TransformState::updatePayloadOps(
-    Value value, function_ref<Operation *(Operation *)> callback) {
-  Mappings &mappings = getMapping(value);
-  auto it = mappings.direct.find(value);
-  assert(it != mappings.direct.end() && "unknown handle");
-  SmallVector<Operation *, 2> &association = it->getSecond();
-  SmallVector<Operation *, 2> updated;
-  updated.reserve(association.size());
+void transform::TransformState::forgetValueMapping(
+    Value valueHandle, ArrayRef<Operation *> payloadOperations) {
+  Mappings &mappings = getMapping(valueHandle);
+  for (Value payloadValue : mappings.reverseValues[valueHandle])
+    dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
+  mappings.values.erase(valueHandle);
+
+  for (Operation *payloadOp : payloadOperations) {
+    SmallVector<Value> opHandles;
+    (void)getHandlesForPayloadOp(payloadOp, opHandles);
+    for (Value opHandle : opHandles) {
+      Mappings &localMappings = getMapping(opHandle);
+      dropMappingEntry(localMappings.direct, opHandle, payloadOp);
+      dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
+    }
+  }
+}
+
+LogicalResult
+transform::TransformState::replacePayloadOp(Operation *op,
+                                            Operation *replacement) {
+  // Drop the mapping between the op and all handles that point to it. Don't
+  // care if there are on such handles.
+  SmallVector<Value> opHandles;
+  (void)getHandlesForPayloadOp(op, opHandles);
+  for (Value handle : opHandles) {
+    Mappings &mappings = getMapping(handle);
+    dropMappingEntry(mappings.reverse, op, handle);
+  }
 
-  for (Operation *op : association) {
-    dropReverseMapping(mappings, op, value);
-    if (Operation *updatedOp = callback(op)) {
-      updated.push_back(updatedOp);
-      mappings.reverse[updatedOp].push_back(value);
+  // Drop the mapping between the op results and all value handles that point to
+  // them. Don't care if there are no such handles.
+  RaggedArray<Value> resultValueHandles;
+  for (Value opResult : op->getResults()) {
+    SmallVector<Value> valueHandles;
+    (void)getHandlesForPayloadValue(opResult, valueHandles);
+    for (Value handle : valueHandles) {
+      Mappings &localMappings = getMapping(handle);
+      dropMappingEntry(localMappings.reverseValues, opResult, handle);
     }
+    resultValueHandles.push_back(std::move(valueHandles));
   }
 
-  auto iface = value.getType().cast<TransformHandleTypeInterface>();
-  DiagnosedSilenceableFailure result =
-      iface.checkPayload(value.getLoc(), updated);
-  if (failed(result.checkAndReport()))
-    return failure();
+  // TODO: consider invalidating the handles to nested objects here.
+
+  // If replacing with null, that is erasing the mapping, drop the mapping
+  // between the handles and the IR objects and return.
+  if (!replacement) {
+    for (Value handle : opHandles) {
+      Mappings &mappings = getMapping(handle);
+      dropMappingEntry(mappings.direct, handle, op);
+    }
+    for (Value opResult : op->getResults()) {
+      SmallVector<Value> valueHandles;
+      (void)getHandlesForPayloadValue(opResult, valueHandles);
+      for (Value handle : valueHandles) {
+        Mappings &localMappings = getMapping(handle);
+        dropMappingEntry(localMappings.values, handle, opResult);
+      }
+    }
+    return success();
+  }
+
+  // Otherwise, replace the pointed-to object of all handles while preserving
+  // their relative order.
+  if (op->getNumResults() != replacement->getNumResults()) {
+    return emitError(op->getLoc())
+           << "cannot replace an op with another op producing a different "
+              "number of results while tracking handles";
+  }
+
+  // Replace the mapped operation if present.
+  for (Value handle : opHandles) {
+    Mappings &mappings = getMapping(handle);
+    auto it = mappings.direct.find(handle);
+    if (it == mappings.direct.end())
+      continue;
+
+    SmallVector<Operation *, 2> &association = it->getSecond();
+    // Note that an operation may be associated with the handle more than once.
+    for (Operation *&mapped : association) {
+      if (mapped == op)
+        mapped = replacement;
+    }
+    mappings.reverse[replacement].push_back(handle);
+  }
+
+  // Replace the mapped results of the operation.
+  for (auto [origResult, replacementResult, handleList] : llvm::zip(
+           op->getResults(), replacement->getResults(), resultValueHandles)) {
+    for (Value resultHandle : handleList) {
+      Mappings &mappings = getMapping(resultHandle);
+      auto it = mappings.values.find(resultHandle);
+      if (it == mappings.values.end())
+        continue;
+
+      SmallVector<Value> &association = it->getSecond();
+      for (Value &mapped : association) {
+        if (mapped == origResult)
+          mapped = replacementResult;
+      }
+      mappings.reverseValues[replacementResult].push_back(resultHandle);
+    }
+  }
 
-  it->second = updated;
   return success();
 }
 
-void transform::TransformState::recordHandleInvalidationOne(
-    OpOperand &handle, Operation *payloadOp, Value otherHandle) {
-  ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
+void transform::TransformState::recordOpHandleInvalidationOne(
+    OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
+    Operation *payloadOp, Value otherHandle, Value throughValue) {
   // If the op is associated with invalidated handle, skip the check as it
   // may be reading invalid IR.
   if (invalidatedHandles.count(otherHandle))
@@ -240,10 +401,13 @@ void transform::TransformState::recordHandleInvalidationOne(
     // deleted before the lambda gets called.
     Location ancestorLoc = ancestor->getLoc();
     Location opLoc = payloadOp->getLoc();
-    Operation *owner = handle.getOwner();
-    unsigned operandNo = handle.getOperandNumber();
+    Operation *owner = consumingHandle.getOwner();
+    unsigned operandNo = consumingHandle.getOperandNumber();
+    std::optional<Location> throughValueLoc =
+        throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
     invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
-                                       otherHandle](Location currentLoc) {
+                                       otherHandle,
+                                       throughValueLoc](Location currentLoc) {
       InFlightDiagnostic diag = emitError(currentLoc)
                                 << "op uses a handle invalidated by a "
                                    "previously executed transform op";
@@ -251,19 +415,144 @@ void transform::TransformState::recordHandleInvalidationOne(
       diag.attachNote(owner->getLoc())
           << "invalidated by this transform op that consumes its operand #"
           << operandNo
-          << " and invalidates handles to payload ops nested in payload "
-             "ops associated with the consumed handle";
+          << " and invalidates all handles to payload IR entities associated "
+             "with this operand and entities nested in them";
       diag.attachNote(ancestorLoc) << "ancestor payload op";
       diag.attachNote(opLoc) << "nested payload op";
+      if (throughValueLoc) {
+        diag.attachNote(*throughValueLoc)
+            << "consumed handle points to this payload value";
+      }
     };
   }
 }
 
-void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
-  for (const Mappings &mapping : llvm::make_second_range(mappings))
-    for (const auto &[payloadOp, otherHandles] : mapping.reverse)
+void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
+    OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
+    Value payloadValue, Value valueHandle) {
+  // If the op is associated with invalidated handle, skip the check as it
+  // may be reading invalid IR.
+  if (invalidatedHandles.count(valueHandle))
+    return;
+
+  for (Operation *ancestor : potentialAncestors) {
+    Operation *definingOp;
+    std::optional<unsigned> resultNo = std::nullopt;
+    unsigned argumentNo, blockNo, regionNo;
+    if (auto opResult = payloadValue.dyn_cast<OpResult>()) {
+      definingOp = opResult.getOwner();
+      resultNo = opResult.getResultNumber();
+    } else {
+      auto arg = payloadValue.cast<BlockArgument>();
+      definingOp = arg.getParentBlock()->getParentOp();
+      argumentNo = arg.getArgNumber();
+      blockNo = std::distance(arg.getOwner()->getParent()->begin(),
+                              arg.getOwner()->getIterator());
+      regionNo = arg.getOwner()->getParent()->getRegionNumber();
+    }
+    assert(definingOp && "expected the value to be defined by an op as result "
+                         "or block argument");
+    if (!ancestor->isAncestor(definingOp))
+      continue;
+
+    Operation *owner = consumingHandle.getOwner();
+    unsigned operandNo = consumingHandle.getOperandNumber();
+    Location ancestorLoc = ancestor->getLoc();
+    Location opLoc = definingOp->getLoc();
+    Location valueLoc = payloadValue.getLoc();
+    invalidatedHandles[valueHandle] =
+        [valueHandle, owner, operandNo, resultNo, argumentNo, blockNo, regionNo,
+         ancestorLoc, opLoc, valueLoc](Location currentLoc) {
+          InFlightDiagnostic diag = emitError(currentLoc)
+                                    << "op uses a handle invalidated by a "
+                                       "previously executed transform op";
+          diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
+          diag.attachNote(owner->getLoc())
+              << "invalidated by this transform op that consumes its operand #"
+              << operandNo
+              << " and invalidates all handles to payload IR entities "
+                 "associated with this operand and entities nested in them";
+          diag.attachNote(ancestorLoc)
+              << "ancestor op associated with the consumed handle";
+          if (resultNo) {
+            diag.attachNote(opLoc)
+                << "op defining the value as result #" << *resultNo;
+          } else {
+            diag.attachNote(opLoc)
+                << "op defining the value as block argument #" << argumentNo
+                << " of block #" << blockNo << " in region #" << regionNo;
+          }
+          diag.attachNote(valueLoc) << "payload value";
+        };
+  }
+}
+
+void transform::TransformState::recordOpHandleInvalidation(
+    OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
+    Value throughValue) {
+  // Iterate over the mapping and invalidate aliasing handles. This is quite
+  // expensive and only necessary for error reporting in case of transform
+  // dialect misuse with dangling handles. Iteration over the handles is based
+  // on the assumption that the number of handles is significantly less than the
+  // number of IR objects (operations and values). Alternatively, we could walk
+  // the IR nested in each payload op associated with the given handle and look
+  // for handles associated with each operation and value.
+  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+    // Go over all op handle mappings and mark as invalidated any handle
+    // pointing to any of the payload ops associated with the given handle or
+    // any op nested in them.
+    for (const auto &[payloadOp, otherHandles] : mapping.reverse) {
       for (Value otherHandle : otherHandles)
-        recordHandleInvalidationOne(handle, payloadOp, otherHandle);
+        recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
+                                      otherHandle, throughValue);
+    }
+    // Go over all value handle mappings and mark as invalidated any handle
+    // pointing to any result of the payload op associated with the given handle
+    // or any op nested in them. Similarly invalidate handles to argument of
+    // blocks belonging to any region of any payload op associated with the
+    // given handle or any op nested in them.
+    for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) {
+      for (Value valueHandle : valueHandles)
+        recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
+                                                   payloadValue, valueHandle);
+    }
+  }
+}
+
+void transform::TransformState::recordValueHandleInvalidation(
+    OpOperand &valueHandle) {
+  // Invalidate other handles to the same value.
+  for (Value payloadValue : getPayloadValues(valueHandle.get())) {
+    SmallVector<Value> otherValueHandles;
+    (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
+    for (Value otherHandle : otherValueHandles) {
+      Operation *owner = valueHandle.getOwner();
+      unsigned operandNo = valueHandle.getOperandNumber();
+      Location valueLoc = payloadValue.getLoc();
+      invalidatedHandles[otherHandle] = [otherHandle, owner, operandNo,
+                                         valueLoc](Location currentLoc) {
+        InFlightDiagnostic diag = emitError(currentLoc)
+                                  << "op uses a handle invalidated by a "
+                                     "previously executed transform op";
+        diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
+        diag.attachNote(owner->getLoc())
+            << "invalidated by this transform op that consumes its operand #"
+            << operandNo
+            << " and invalidates handles to the same values as associated with "
+               "it";
+        diag.attachNote(valueLoc) << "payload value";
+      };
+    }
+
+    if (auto opResult = payloadValue.dyn_cast<OpResult>()) {
+      Operation *payloadOp = opResult.getOwner();
+      recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue);
+    } else {
+      auto arg = payloadValue.dyn_cast<BlockArgument>();
+      for (Operation &payloadOp : *arg.getOwner())
+        recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue);
+    }
+  }
 }
 
 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
@@ -287,13 +576,44 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
       return isa<MemoryEffects::Free>(effect.getEffect()) &&
              effect.getValue() == target.get();
     };
-    if (llvm::any_of(effects, consumesTarget))
-      recordHandleInvalidation(target);
+    if (llvm::any_of(effects, consumesTarget)) {
+      if (target.get().getType().isa<TransformHandleTypeInterface>()) {
+        ArrayRef<Operation *> payloadOps = getPayloadOps(target.get());
+        recordOpHandleInvalidation(target, payloadOps);
+      } else if (target.get()
+                     .getType()
+                     .isa<TransformValueHandleTypeInterface>()) {
+        recordValueHandleInvalidation(target);
+      }
+    }
   }
 
   return success();
 }
 
+template <typename T>
+DiagnosedSilenceableFailure
+checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
+                                  transform::TransformOpInterface transform,
+                                  unsigned operandNumber) {
+  DenseSet<T> seen;
+  for (T p : payload) {
+    if (!seen.insert(p).second) {
+      DiagnosedSilenceableFailure diag =
+          transform.emitSilenceableError()
+          << "a handle passed as operand #" << operandNumber
+          << " and consumed by this operation points to a payload "
+             "entity more than once";
+      if constexpr (std::is_pointer_v<T>)
+        diag.attachNote(p->getLoc()) << "repeated target op";
+      else
+        diag.attachNote(p.getLoc()) << "repeated target value";
+      return diag;
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 DiagnosedSilenceableFailure
 transform::TransformState::applyTransform(TransformOpInterface transform) {
   LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
@@ -313,25 +633,82 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
       if (!isHandleConsumed(operand.get(), transform))
         continue;
 
-      DenseSet<Operation *> seen;
-      for (Operation *op : getPayloadOps(operand.get())) {
-        if (!seen.insert(op).second) {
-          DiagnosedSilenceableFailure diag =
-              transform.emitSilenceableError()
-              << "a handle passed as operand #" << operand.getOperandNumber()
-              << " and consumed by this operation points to a payload "
-                 "operation more than once";
-          diag.attachNote(op->getLoc()) << "repeated target op";
-          return diag;
+      Type operandType = operand.get().getType();
+      if (operandType.isa<TransformHandleTypeInterface>()) {
+        DiagnosedSilenceableFailure check =
+            checkRepeatedConsumptionInOperand<Operation *>(
+                getPayloadOps(operand.get()), transform,
+                operand.getOperandNumber());
+        if (!check.succeeded())
+          return check;
+      } else if (operandType.isa<TransformValueHandleTypeInterface>()) {
+        DiagnosedSilenceableFailure check =
+            checkRepeatedConsumptionInOperand<Value>(
+                getPayloadValues(operand.get()), transform,
+                operand.getOperandNumber());
+        if (!check.succeeded())
+          return check;
+      }
+    }
+  }
+
+  // Find which operands are consumed.
+  DenseSet<unsigned> consumedOperands;
+  auto memEffectInterface =
+      cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance, 2> effects;
+  for (OpOperand &target : transform->getOpOperands()) {
+    effects.clear();
+    memEffectInterface.getEffectsOnValue(target.get(), effects);
+    if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
+          return isa<transform::TransformMappingResource>(
+                     effect.getResource()) &&
+                 isa<MemoryEffects::Free>(effect.getEffect());
+        })) {
+      consumedOperands.insert(target.getOperandNumber());
+    }
+  }
+
+  // Remember the results of the payload ops associated with the consumed
+  // op handles or the ops defining the value handles so we can drop the
+  // association with them later. This must happen here because the
+  // transformation may destroy or mutate them so we cannot traverse the payload
+  // IR after that.
+  SmallVector<Value> origOpFlatResults;
+  SmallVector<Operation *> origAssociatedOps;
+  for (unsigned index : consumedOperands) {
+    Value operand = transform->getOperand(index);
+    if (operand.getType().isa<TransformHandleTypeInterface>()) {
+      for (Operation *payloadOp : getPayloadOps(operand))
+        llvm::append_range(origOpFlatResults, payloadOp->getResults());
+      continue;
+    }
+    if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
+      for (Value payloadValue : getPayloadValues(operand)) {
+        if (payloadValue.isa<OpResult>()) {
+          origAssociatedOps.push_back(payloadValue.getDefiningOp());
+          continue;
         }
+        llvm::append_range(
+            origAssociatedOps,
+            llvm::map_range(*payloadValue.cast<BlockArgument>().getOwner(),
+                            [](Operation &op) { return &op; }));
       }
+      continue;
     }
+    DiagnosedDefiniteFailure diag =
+        emitDefiniteFailure(transform->getLoc())
+        << "unexpectedly consumed a value that is not a handle as operand #"
+        << index;
+    diag.attachNote(operand.getLoc())
+        << "value defined here with type " << operand.getType();
+    return diag;
   }
 
-  transform::TransformResults results(transform->getNumResults());
   // Compute the result but do not short-circuit the silenceable failure case as
   // we still want the handles to propagate properly so the "suppress" mode can
   // proceed on a best effort basis.
+  transform::TransformResults results(transform->getNumResults());
   DiagnosedSilenceableFailure result(transform.apply(results, *this));
   if (result.isDefiniteFailure())
     return result;
@@ -352,18 +729,12 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
 
   // Remove the mapping for the operand if it is consumed by the operation. This
   // allows us to catch use-after-free with assertions later on.
-  auto memEffectInterface =
-      cast<MemoryEffectOpInterface>(transform.getOperation());
-  SmallVector<MemoryEffects::EffectInstance, 2> effects;
-  for (OpOperand &target : transform->getOpOperands()) {
-    effects.clear();
-    memEffectInterface.getEffectsOnValue(target.get(), effects);
-    if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
-          return isa<transform::TransformMappingResource>(
-                     effect.getResource()) &&
-                 isa<MemoryEffects::Free>(effect.getEffect());
-        })) {
-      removePayloadOps(target.get());
+  for (unsigned index : consumedOperands) {
+    Value operand = transform->getOperand(index);
+    if (operand.getType().isa<TransformHandleTypeInterface>()) {
+      forgetMapping(operand, origOpFlatResults);
+    } else if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
+      forgetValueMapping(operand, origAssociatedOps);
     }
   }
 
@@ -378,6 +749,13 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
               setParams(result, results.getParams(result.getResultNumber())))) {
         return DiagnosedSilenceableFailure::definiteFailure();
       }
+    } else if (result.getType().isa<TransformValueHandleTypeInterface>()) {
+      assert(results.isValue(result.getResultNumber()) &&
+             "expected values for value-type-result");
+      if (failed(setPayloadValues(
+              result, results.getValues(result.getResultNumber())))) {
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
     } else {
       assert(!results.isParam(result.getResultNumber()) &&
              "expected payload ops for the non-parameter typed result");
@@ -409,15 +787,9 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op,
   if (failed(state.getHandlesForPayloadOp(op, handles)))
     return failure();
 
-  for (Value handle : handles) {
-    LogicalResult result =
-        state.updatePayloadOps(handle, [&](Operation *current) {
-          return current == op ? replacement : current;
-        });
-    if (failed(result))
-      return failure();
-  }
-  return success();
+  // TODO: we may need to invalidate handles to operations and values nested in
+  // the operation being replaced.
+  return state.replacePayloadOp(op, replacement);
 }
 
 //===----------------------------------------------------------------------===//
@@ -425,63 +797,95 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op,
 //===----------------------------------------------------------------------===//
 
 transform::TransformResults::TransformResults(unsigned numSegments) {
-  segments.resize(numSegments,
-                  ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
-  paramSegments.resize(numSegments, ArrayRef<TransformState::Param>(
-                                        nullptr, static_cast<size_t>(0)));
+  operations.appendEmptyRows(numSegments);
+  params.appendEmptyRows(numSegments);
+  values.appendEmptyRows(numSegments);
 }
 
 void transform::TransformResults::set(OpResult value,
                                       ArrayRef<Operation *> ops) {
   int64_t position = value.getResultNumber();
-  assert(position < static_cast<int64_t>(segments.size()) &&
+  assert(position < static_cast<int64_t>(operations.size()) &&
          "setting results for a non-existent handle");
-  assert(segments[position].data() == nullptr && "results already set");
-  int64_t start = operations.size();
-  llvm::append_range(operations, ops);
-  segments[position] = ArrayRef(operations).drop_front(start);
+  assert(operations[position].data() == nullptr && "results already set");
+  assert(params[position].data() == nullptr &&
+         "another kind of results already set");
+  assert(values[position].data() == nullptr &&
+         "another kind of results already set");
+  operations.replace(position, ops);
 }
 
 void transform::TransformResults::setParams(
     OpResult value, ArrayRef<transform::TransformState::Param> params) {
   int64_t position = value.getResultNumber();
-  assert(position < static_cast<int64_t>(paramSegments.size()) &&
+  assert(position < static_cast<int64_t>(this->params.size()) &&
          "setting params for a non-existent handle");
-  assert(paramSegments[position].data() == nullptr && "params already set");
-  size_t start = this->params.size();
-  llvm::append_range(this->params, params);
-  paramSegments[position] = ArrayRef(this->params).drop_front(start);
+  assert(this->params[position].data() == nullptr && "params already set");
+  assert(operations[position].data() == nullptr &&
+         "another kind of results already set");
+  assert(values[position].data() == nullptr &&
+         "another kind of results already set");
+  this->params.replace(position, params);
+}
+
+void transform::TransformResults::setValues(OpResult handle,
+                                            ValueRange values) {
+  int64_t position = handle.getResultNumber();
+  assert(position < static_cast<int64_t>(values.size()) &&
+         "setting values for a non-existent handle");
+  assert(this->values[position].data() == nullptr && "values already set");
+  assert(operations[position].data() == nullptr &&
+         "another kind of results already set");
+  assert(params[position].data() == nullptr &&
+         "another kind of results already set");
+  this->values.replace(position, values);
 }
 
 ArrayRef<Operation *>
 transform::TransformResults::get(unsigned resultNumber) const {
-  assert(resultNumber < segments.size() &&
+  assert(resultNumber < operations.size() &&
          "querying results for a non-existent handle");
-  assert(segments[resultNumber].data() != nullptr &&
-         "querying unset results (param expected?)");
-  return segments[resultNumber];
+  assert(operations[resultNumber].data() != nullptr &&
+         "querying unset results (values or params expected?)");
+  return operations[resultNumber];
 }
 
 ArrayRef<transform::TransformState::Param>
 transform::TransformResults::getParams(unsigned resultNumber) const {
-  assert(resultNumber < paramSegments.size() &&
+  assert(resultNumber < params.size() &&
          "querying params for a non-existent handle");
-  assert(paramSegments[resultNumber].data() != nullptr &&
-         "querying unset params (payload ops expected?)");
-  return paramSegments[resultNumber];
+  assert(params[resultNumber].data() != nullptr &&
+         "querying unset params (ops or values expected?)");
+  return params[resultNumber];
+}
+
+ArrayRef<Value>
+transform::TransformResults::getValues(unsigned resultNumber) const {
+  assert(resultNumber < params.size() &&
+         "querying params for a non-existent handle");
+  assert(values[resultNumber].data() != nullptr &&
+         "querying unset values (ops or params expected?)");
+  return values[resultNumber];
 }
 
 bool transform::TransformResults::isParam(unsigned resultNumber) const {
-  assert(resultNumber < paramSegments.size() &&
+  assert(resultNumber < params.size() &&
          "querying association for a non-existent handle");
-  return paramSegments[resultNumber].data() != nullptr;
+  return params[resultNumber].data() != nullptr;
+}
+
+bool transform::TransformResults::isValue(unsigned resultNumber) const {
+  assert(resultNumber < values.size() &&
+         "querying association for a non-existent handle");
+  return values[resultNumber].data() != nullptr;
 }
 
 bool transform::TransformResults::isSet(unsigned resultNumber) const {
-  assert(resultNumber < paramSegments.size() &&
+  assert(resultNumber < params.size() &&
          "querying association for a non-existent handle");
-  return paramSegments[resultNumber].data() != nullptr ||
-         segments[resultNumber].data() != nullptr;
+  return params[resultNumber].data() != nullptr ||
+         operations[resultNumber].data() != nullptr ||
+         values[resultNumber].data() != nullptr;
 }
 
 //===----------------------------------------------------------------------===//
@@ -547,6 +951,12 @@ void transform::detail::setApplyToOneResults(
             return oneResult[r.getResultNumber()].get<Attribute>();
           }));
       transformResults.setParams(r, params);
+    } else if (r.getType().isa<TransformValueHandleTypeInterface>()) {
+      auto values = llvm::to_vector(
+          llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
+            return oneResult[r.getResultNumber()].get<Value>();
+          }));
+      transformResults.setValues(r, values);
     } else {
       auto payloads = llvm::to_vector(
           llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
@@ -571,6 +981,8 @@ LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
       SmallVector<MappedValue> &mapped = extraMappings.emplace_back();
       if (operand.getType().isa<TransformHandleTypeInterface>()) {
         llvm::append_range(mapped, state.getPayloadOps(operand));
+      } else if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
+        llvm::append_range(mapped, state.getPayloadValues(operand));
       } else {
         assert(operand.getType().isa<TransformParamTypeInterface>() &&
                "unsupported kind of transform dialect value");
@@ -639,13 +1051,15 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
   }
   for (BlockArgument arg : body->getArguments().drop_front()) {
     if (arg.getType()
-            .isa<TransformHandleTypeInterface, TransformParamTypeInterface>())
+            .isa<TransformHandleTypeInterface, TransformParamTypeInterface,
+                 TransformValueHandleTypeInterface>())
       continue;
 
     InFlightDiagnostic diag =
         op->emitOpError()
         << "expects trailing entry block arguments to be of type implementing "
-           "TransformHandleTypeInterface or TransformParamTypeInterface";
+           "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
+           "TransformParamTypeInterface";
     diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
     return diag;
   }
@@ -675,7 +1089,9 @@ void transform::detail::getParamProducerTransformOpTraitEffects(
   bool hasPayloadOperands = false;
   for (Value operand : op->getOperands()) {
     onlyReadsHandle(operand, effects);
-    if (operand.getType().isa<TransformHandleTypeInterface>())
+    if (operand.getType()
+            .isa<TransformHandleTypeInterface,
+                 TransformValueHandleTypeInterface>())
       hasPayloadOperands = true;
   }
   if (hasPayloadOperands)
@@ -841,7 +1257,7 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
 LogicalResult
 transform::applyTransforms(Operation *payloadRoot,
                            TransformOpInterface transform,
-                           ArrayRef<ArrayRef<MappedValue>> extraMapping,
+                           const RaggedArray<MappedValue> &extraMapping,
                            const TransformOptions &options) {
 #ifndef NDEBUG
   if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
index 4e4fb2d..579e106 100644 (file)
@@ -99,3 +99,13 @@ transform::ParamType::checkPayload(Location loc,
   }
   return DiagnosedSilenceableFailure::success();
 }
+
+//===----------------------------------------------------------------------===//
+// transform::AnyValueType
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::AnyValueType::checkPayload(Location loc,
+                                      ArrayRef<Value> payload) const {
+  return DiagnosedSilenceableFailure::success();
+}
index a7456e3..3d6ee21 100644 (file)
@@ -279,7 +279,7 @@ static void performOptionalDebugActions(
 LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
-    ArrayRef<ArrayRef<MappedValue>> extraMappings,
+    const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
     const Pass::Option<std::string> &debugPayloadRootTag,
index e41ffec..cbc62a6 100644 (file)
@@ -13,11 +13,11 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %match_name = transform.structured.match ops{["arith.constant"]} in %arg1 : (!pdl.operation) -> !pdl.operation
   transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation
-  transform.test_consume_operand %match_name
+  transform.test_consume_operand %match_name : !pdl.operation
 
   %match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1 : (!pdl.operation) -> !pdl.operation
   transform.test_print_remark_at_operand %match_attr, "matched attr name" : !pdl.operation
-  transform.test_consume_operand %match_attr
+  transform.test_consume_operand %match_attr : !pdl.operation
 }
 
 // -----
@@ -34,7 +34,7 @@ transform.sequence failures(propagate) {
   %match_name = transform.structured.match
     ops{["arith.constant"]} filter_result_type = f32 in %arg1 : (!pdl.operation) -> !pdl.operation
   transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation
-  transform.test_consume_operand %match_name
+  transform.test_consume_operand %match_name : !pdl.operation
 }
 
 // -----
@@ -65,7 +65,7 @@ transform.sequence failures(propagate) {
         #linalg.iterator_type<parallel>]}
       in %arg1 : (!pdl.operation) -> !pdl.operation
   transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation
-  transform.test_consume_operand %match_attr
+  transform.test_consume_operand %match_attr : !pdl.operation
 
   %no_match = transform.structured.match
       attributes{iterator_types = [
index 8076163..e2c0f07 100644 (file)
@@ -2,7 +2,7 @@
 
 func.func @use_after_free_branching_control_flow() {
   // expected-note @below {{allocated here}}
-  %0 = transform.test_produce_param_or_forward_operand 42
+  %0 = transform.test_produce_self_handle_or_forward_operand
   transform.test_transform_op_with_regions {
     "transform.test_branching_transform_op_terminator"() : () -> ()
   },
@@ -11,7 +11,7 @@ func.func @use_after_free_branching_control_flow() {
     "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
   ^bb1:
     // expected-note @below {{freed here}}
-    transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
     "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
   ^bb2:
     "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
@@ -29,7 +29,7 @@ func.func @use_after_free_branching_control_flow() {
 
 func.func @use_after_free_in_nested_op() {
   // expected-note @below {{allocated here}}
-  %0 = transform.test_produce_param_or_forward_operand 42
+  %0 = transform.test_produce_self_handle_or_forward_operand
   // expected-note @below {{freed here}}
   transform.test_transform_op_with_regions {
     "transform.test_branching_transform_op_terminator"() : () -> ()
@@ -38,7 +38,7 @@ func.func @use_after_free_in_nested_op() {
   ^bb0:
     "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
   ^bb1:
-    transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
     "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
   ^bb2:
     "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
@@ -74,7 +74,7 @@ func.func @use_after_free_recursive_side_effects() {
     // expected-note @below {{freed here}}
     transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 4 } {
     ^bb4(%arg4: !pdl.operation):
-      test_consume_operand_if_matches_param_or_fail %0[42]
+      test_consume_operand_of_op_kind_or_fail %0, "transform.sequence"
     }
     // expected-warning @below {{operand #0 may be used after free}}
     transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 5 } {
@@ -102,7 +102,7 @@ func.func @use_after_free() {
     }
 
     // expected-note @below {{freed here}}
-    test_consume_operand_if_matches_param_or_fail %0[42]
+    test_consume_operand_of_op_kind_or_fail %0, "transform.sequence"
     // expected-warning @below {{operand #0 may be used after free}}
     transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 5 } {
     ^bb3(%arg3: !pdl.operation):
@@ -118,7 +118,7 @@ func.func @use_after_free() {
 // be reported as use-after-free.
 func.func @use_after_free_self_cycle() {
   // expected-note @below {{allocated here}}
-  %0 = transform.test_produce_param_or_forward_operand 42
+  %0 = transform.test_produce_self_handle_or_forward_operand
   transform.test_transform_op_with_regions {
     "transform.test_branching_transform_op_terminator"() : () -> ()
   },
@@ -132,7 +132,7 @@ func.func @use_after_free_self_cycle() {
     }
     // expected-warning @below {{operand #0 may be used after free}}
     // expected-note @below {{freed here}}
-    transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
     "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
   ^bb2:
     "transform.test_branching_transform_op_terminator"() : () -> ()
@@ -147,7 +147,7 @@ func.func @use_after_free_self_cycle() {
 // use-after-free.
 func.func @use_after_free_cycle() {
   // expected-note @below {{allocated here}}
-  %0 = transform.test_produce_param_or_forward_operand 42
+  %0 = transform.test_produce_self_handle_or_forward_operand
   transform.test_transform_op_with_regions {
     "transform.test_branching_transform_op_terminator"() : () -> ()
   },
@@ -157,7 +157,7 @@ func.func @use_after_free_cycle() {
   ^bb1:
     // expected-warning @below {{operand #0 may be used after free}}
     // expected-note @below {{freed here}}
-    transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
     "transform.test_branching_transform_op_terminator"()[^bb2, ^bb3] : () -> ()
   ^bb2:
     "transform.test_branching_transform_op_terminator"()[^bb1] : () -> ()
index abc09d7..de53554 100644 (file)
@@ -21,7 +21,7 @@ transform.with_pdl_patterns {
     %0 = pdl_match @return in %arg1 : (!pdl.operation) -> !pdl.operation
     %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
     // expected-note @below {{invalidated by this transform op that consumes its operand #0}}
-    test_consume_operand %1
+    test_consume_operand %1 : !pdl.operation
     // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
     test_print_remark_at_operand %0, "remark" : !pdl.operation
   }
@@ -55,8 +55,8 @@ transform.with_pdl_patterns {
     %0 = pdl_match @func in %arg1 : (!pdl.operation) -> !pdl.operation
     %1 = pdl_match @return in %arg1 : (!pdl.operation) -> !pdl.operation
     %2 = replicate num(%0) %1 : !pdl.operation, !pdl.operation
-    // expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}}
-    test_consume_operand %2
+    // expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload entity more than once}}
+    test_consume_operand %2 : !pdl.operation
     test_print_remark_at_operand %0, "remark" : !pdl.operation
   }
 }
@@ -74,9 +74,9 @@ module {
     // expected-note @below {{handle to invalidated ops}}
     %2 = transform.test_copy_payload %0
     // expected-note @below {{invalidated by this transform op that consumes its operand #0}}
-    transform.test_consume_operand %1
+    transform.test_consume_operand %1 : !pdl.operation
     // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
-    transform.test_consume_operand %2
+    transform.test_consume_operand %2 : !pdl.operation
   }
 }
 
@@ -95,8 +95,8 @@ module {
     // to overlapping sets of payload IR ops.
     //
     // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
-    // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates handles}}
-    transform.test_consume_operand %1, %2
+    // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities}}
+    transform.test_consume_operand %1, %2 : !pdl.operation
   }
 }
 
@@ -113,3 +113,221 @@ module {
     transform.merge_handles %1, %2 { deduplicate } : !pdl.operation
   }
 }
+// -----
+
+// expected-note @below {{payload value}}
+%0 = "test.match_anchor"() : () -> (i32)
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-note @below {{invalidated handle}}
+  %4 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates handles to the same values as associated with it}}
+  test_consume_operand %3 : !transform.any_value
+  // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+  test_consume_operand %4 : !transform.any_value
+}
+
+// -----
+
+// expected-note @below {{ancestor op associated with the consumed handle}}
+// expected-note @below {{payload value}}
+// expected-note @below {{op defining the value as result #0}}
+%0 = "test.match_anchor"() : () -> (i32)
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-note @below {{invalidated handle}}
+  %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
+  test_consume_operand %2 : !transform.any_op
+  // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+  test_consume_operand %3 : !transform.any_value
+}
+
+// -----
+
+// expected-note @below {{ancestor op associated with the consumed handle}}
+"test.match_anchor_1"() ({
+^bb0:
+  // expected-note @below {{op defining the value as result #0}}
+  // expected-note @below {{payload value}}
+  %0 = "test.match_anchor_2"() : () -> (i32)
+  "test.region_terminator"() : () -> ()
+}) : () -> ()
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-note @below {{invalidated handle}}
+  %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
+  test_consume_operand %1 : !transform.any_op
+  // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+  test_consume_operand %3 : !transform.any_value
+}
+
+// -----
+
+// expected-note @below {{ancestor op associated with the consumed handle}}
+// expected-note @below {{op defining the value as block argument #0 of block #0 in region #0}}
+"test.match_anchor_1"() ({
+// expected-note @below {{payload value}}
+^bb0(%arg0: i32):
+  %0 = "test.match_anchor_2"() : () -> (i32)
+  "test.region_terminator"() : () -> ()
+}) : () -> ()
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-note @below {{invalidated handle}}
+  %3 = test_produce_value_handle_to_argument_of_parent_block %2, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
+  test_consume_operand %1 : !transform.any_op
+  // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+  test_consume_operand %3 : !transform.any_value
+}
+
+// -----
+
+// expected-note @below {{ancestor op associated with the consumed handle}}
+"test.match_anchor_1"() ({
+^bb:
+  // expected-note @below {{op defining the value as block argument #0 of block #0 in region #0}}
+  "test.op_with_regions"() ({
+  // expected-note @below {{payload value}}
+  ^bb0(%arg0: i32):
+    %0 = "test.match_anchor_2"() : () -> (i32)
+    "test.region_terminator"() : () -> ()
+  }): () -> ()
+}) : () -> ()
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-note @below {{invalidated handle}}
+  %3 = test_produce_value_handle_to_argument_of_parent_block %2, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
+  test_consume_operand %1 : !transform.any_op
+  // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+  test_consume_operand %3 : !transform.any_value
+}
+
+// -----
+
+// expected-note @below {{ancestor payload op}}
+// expected-note @below {{nested payload op}}
+// expected-note @below {{consumed handle points to this payload value}}
+%0 = "test.match_anchor"() : () -> (i32)
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  // expected-note @below {{handle to invalidated ops}}
+  %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %3 = test_produce_value_handle_to_result %2, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
+  test_consume_operand %3 : !transform.any_value
+  // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+  test_consume_operand %2 : !transform.any_op 
+}
+
+// -----
+
+// expected-note @below {{ancestor payload op}}
+// expected-note @below {{consumed handle points to this payload value}}
+%0 = "test.match_anchor_1"() ({
+^bb0:
+  // expected-note @below {{nested payload op}}
+  "test.match_anchor_2"() : () -> ()
+  "test.region_terminator"() : () -> ()
+}) : () -> (i32)
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-note @below {{handle to invalidated ops}}
+  %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %3 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
+  test_consume_operand %3 : !transform.any_value
+  // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+  test_consume_operand %2 : !transform.any_op
+}
+
+
+// -----
+
+"test.match_anchor_1"() ({
+// expected-note @below {{consumed handle points to this payload value}}
+^bb0(%arg0: f32):
+  // expected-note @below {{ancestor payload op}}
+  // expected-note @below {{nested payload op}}
+  "test.match_anchor_2"() : () -> ()
+  "test.region_terminator"() : () -> ()
+}) : () -> ()
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  // expected-note @below {{handle to invalidated ops}}
+  %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %3 = test_produce_value_handle_to_argument_of_parent_block %2, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
+  test_consume_operand %3 : !transform.any_value
+  // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+  test_consume_operand %2 : !transform.any_op
+}
+
+// -----
+
+"test.op_with_regions"() ({
+// expected-note @below {{consumed handle points to this payload value}}
+^bb(%arg0: i32):
+  // expected-note @below {{ancestor payload op}}
+  "test.op_with_regions"() ({
+  ^bb0:
+    // expected-note @below {{nested payload op}}
+    "test.match_anchor_2"() : () -> ()
+    "test.region_terminator"() : () -> ()
+  }): () -> ()
+  "test.match_anchor_1"() : () -> ()
+}) : () -> ()
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-note @below {{handle to invalidated ops}}
+  %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %3 = test_produce_value_handle_to_argument_of_parent_block %1, 0 : (!transform.any_op) -> !transform.any_value
+  // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
+  test_consume_operand %3 : !transform.any_value
+  // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+  test_consume_operand %2 : !transform.any_op
+}
+
+// -----
+
+// Removing a block argument does not invalidate handles to operations in another block.
+// Not expecting an error here.
+
+"test.op_with_regions"() ({
+^bb1(%arg0: i32):
+  "test.match_anchor_1"() : () -> ()
+^bb2:
+  "test.match_anchor_2"() : () -> ()
+}) : () -> ()
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %1 = transform.structured.match ops{["test.match_anchor_1"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.match ops{["test.match_anchor_2"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %3 = test_produce_value_handle_to_argument_of_parent_block %1, 0 : (!transform.any_op) -> !transform.any_value
+  test_consume_operand %3 : !transform.any_value
+  test_consume_operand %2 : !transform.any_op
+}
index 447c6b4..5147836 100644 (file)
@@ -37,6 +37,17 @@ func.func @foo() {
 
 // -----
 
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value):
+  // expected-error @above {{wrong kind of value provided for the top-level value handle}}
+}
+
+func.func @foo() {
+  return
+}
+
+// -----
+
 // expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op):
diff --git a/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir b/mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir
new file mode 100644 (file)
index 0000000..431b0c5
--- /dev/null
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-results-of-ops=test.some_returning_op bind-second-extra-to-results-of-ops=test.some_other_returning_op})' \
+// RUN:             --split-input-file --verify-diagnostics
+
+// Note that diagnostic checker will merge two diagnostics with the same message
+// at the same location, so only check the remark once.
+// 
+// expected-remark @below {{first extra}}
+// expected-note @below {{value handle points to an op result #0}}
+// expected-note @below {{value handle points to an op result #1}}
+%0:2 = "test.some_returning_op"() : () -> (i32, i64)
+
+// expected-remark @below {{first extra}}
+// expected-note @below {{value handle points to an op result #0}}
+%1 = "test.some_returning_op"() : () -> index
+
+// Note that diagnostic checker will merge two diagnostics with the same message
+// at the same location, so only check the remark once.
+// 
+// expected-remark @below {{second extra}}
+// expected-note @below {{value handle points to an op result #0}}
+// expected-note @below {{value handle points to an op result #1}}
+%2:2 = "test.some_other_returning_op"() : () -> (f32, f64)
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value, %arg2: !transform.any_value):
+  test_print_remark_at_operand_value %arg1, "first extra" : !transform.any_value
+  test_print_remark_at_operand_value %arg2, "second extra" : !transform.any_value
+}
+
+// -----
+
+%0:2 = "test.some_returning_op"() : () -> (i32, i64)
+%1 = "test.some_returning_op"() : () -> index
+
+transform.sequence failures(propagate) {
+// expected-error @below {{wrong kind of value provided for top-level operation handle}}
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value):
+}
+
+// -----
+
+// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value):
+}
index d2142db..4abaa23 100644 (file)
@@ -24,7 +24,7 @@ transform.sequence failures(propagate) {
 
 // -----
 
-// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface or TransformParamTypeInterface}}
+// expected-error @below {{'transform.sequence' op expects trailing entry block arguments to be of type implementing TransformHandleTypeInterface, TransformValueHandleTypeInterface or TransformParamTypeInterface}}
 // expected-note @below {{argument #1 does not}}
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op, %arg1: i64):
@@ -166,11 +166,11 @@ transform.sequence failures(propagate) {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   // expected-error @below {{result #0 has more than one potential consumer}}
-  %0 = test_produce_param_or_forward_operand 42
+  %0 = test_produce_self_handle_or_forward_operand
   // expected-note @below {{used here as operand #0}}
-  test_consume_operand_if_matches_param_or_fail %0[42]
+  test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
   // expected-note @below {{used here as operand #0}}
-  test_consume_operand_if_matches_param_or_fail %0[42]
+  test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
 }
 
 // -----
@@ -178,13 +178,13 @@ transform.sequence failures(propagate) {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   // expected-error @below {{result #0 has more than one potential consumer}}
-  %0 = test_produce_param_or_forward_operand 42
+  %0 = test_produce_self_handle_or_forward_operand
   // expected-note @below {{used here as operand #0}}
-  test_consume_operand_if_matches_param_or_fail %0[42]
+  test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
   // expected-note @below {{used here as operand #0}}
   transform.sequence %0 : !pdl.operation failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
-    test_consume_operand_if_matches_param_or_fail %arg1[42]
+    test_consume_operand_of_op_kind_or_fail %arg1, "transform.test_produce_self_handle_or_forward_operand"
   }
 }
 
@@ -193,13 +193,13 @@ transform.sequence failures(propagate) {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   // expected-error @below {{result #0 has more than one potential consumer}}
-  %0 = test_produce_param_or_forward_operand 42
+  %0 = test_produce_self_handle_or_forward_operand
   // expected-note @below {{used here as operand #0}}
-  test_consume_operand_if_matches_param_or_fail %0[42]
+  test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
   transform.sequence %0 : !pdl.operation failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
     // expected-note @below {{used here as operand #0}}
-    test_consume_operand_if_matches_param_or_fail %0[42]
+    test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
   }
 }
 
@@ -208,15 +208,15 @@ transform.sequence failures(propagate) {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   // expected-error @below {{result #0 has more than one potential consumer}}
-  %0 = test_produce_param_or_forward_operand 42
+  %0 = test_produce_self_handle_or_forward_operand
   // expected-note @below {{used here as operand #0}}
-  test_consume_operand_if_matches_param_or_fail %0[42]
+  test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
   // expected-note @below {{used here as operand #0}}
   transform.sequence %0 : !pdl.operation failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
     transform.sequence %arg1 : !pdl.operation failures(propagate) {
     ^bb2(%arg2: !pdl.operation):
-      test_consume_operand_if_matches_param_or_fail %arg2[42]
+      test_consume_operand_of_op_kind_or_fail %arg2, "transform.test_produce_self_handle_or_forward_operand"
     }
   }
 }
@@ -257,14 +257,14 @@ transform.alternatives {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   // expected-error @below {{result #0 has more than one potential consumer}}
-  %0 = test_produce_param_or_forward_operand 42
+  %0 = test_produce_self_handle_or_forward_operand
   // expected-note @below {{used here as operand #0}}
   transform.foreach %0 : !pdl.operation {
   ^bb1(%arg1: !pdl.operation):
-    transform.test_consume_operand %arg1
+    transform.test_consume_operand %arg1 : !pdl.operation
   }
   // expected-note @below {{used here as operand #0}}
-  transform.test_consume_operand %0
+  transform.test_consume_operand %0 : !pdl.operation
 }
 
 // -----
index f1c2762..5bbda8e 100644 (file)
@@ -6,11 +6,11 @@
 // CHECK: transform.test_transform_op
 transform.test_transform_op
 
-// CHECK: = transform.test_produce_param_or_forward_operand 42 {foo = "bar"}
-%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+// CHECK: = transform.test_produce_self_handle_or_forward_operand {foo = "bar"}
+%0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" }
 
-// CHECK: transform.test_consume_operand_if_matches_param_or_fail %{{.*}}[42]
-transform.test_consume_operand_if_matches_param_or_fail %0[42]
+// CHECK: transform.test_consume_operand_of_op_kind_or_fail %{{.*}},
+transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
 
 // Ensure that the extension type is roundtripped correctly.
 // CHECK: transform.cast %{{.*}} : !pdl.operation to !transform.test_dialect_op
index f470606..e8bc530 100644 (file)
@@ -10,18 +10,18 @@ transform.sequence failures(propagate) {
 
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
-  %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+  %0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" }
   // expected-remark @below {{succeeded}}
-  transform.test_consume_operand_if_matches_param_or_fail %0[42]
+  transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
 }
 
 // -----
 
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
-  %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
-  // expected-error @below {{expected the operand to be associated with 21 got 42}}
-  transform.test_consume_operand_if_matches_param_or_fail %0[21]
+  %0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" }
+  // expected-error @below {{expected the operand to be associated a payload op of kind transform.sequence got transform.test_produce_self_handle_or_forward_operand}}
+  transform.test_consume_operand_of_op_kind_or_fail %0, "transform.sequence"
 }
 
 // -----
@@ -31,10 +31,10 @@ transform.sequence failures(propagate) {
 // to detect double-consumption.
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
-  %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+  %0 = transform.test_produce_self_handle_or_forward_operand { foo = "bar" }
   %1 = transform.test_copy_payload %0
   // expected-remark @below {{succeeded}}
-  transform.test_consume_operand_if_matches_param_or_fail %0[42]
+  transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
 }
 
 // -----
@@ -60,11 +60,11 @@ transform.sequence failures(propagate) {
 
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
-  %0 = test_produce_param_or_forward_operand 42
+  %0 = test_produce_self_handle_or_forward_operand
   sequence %0 : !pdl.operation failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     // expected-remark @below {{succeeded}}
-    test_consume_operand_if_matches_param_or_fail %arg1[42]
+    test_consume_operand_of_op_kind_or_fail %arg1, "transform.test_produce_self_handle_or_forward_operand"
   }
 }
 
@@ -74,11 +74,11 @@ transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   %0 = sequence %arg0 : !pdl.operation -> !pdl.operation failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
-    %1 = test_produce_param_or_forward_operand 42
+    %1 = test_produce_self_handle_or_forward_operand
     yield %1 : !pdl.operation
   }
   // expected-remark @below {{succeeded}}
-  test_consume_operand_if_matches_param_or_fail %0[42]
+  test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
 }
 
 // -----
@@ -163,15 +163,15 @@ transform.with_pdl_patterns {
     %0 = pdl_match @match_func in %arg1 : (!pdl.operation) -> !pdl.operation
     transform.alternatives %0 : !pdl.operation {
     ^bb2(%arg2: !pdl.operation):
-      %1 = transform.test_produce_param_or_forward_operand 42
+      %1 = transform.test_produce_self_handle_or_forward_operand
       // This operation fails, which triggers the next alternative without
       // reporting the error.
-      transform.test_consume_operand_if_matches_param_or_fail %1[43]
+      transform.test_consume_operand_of_op_kind_or_fail %1, "transform.sequence"
     }, {
     ^bb2(%arg2: !pdl.operation):
-      %1 = transform.test_produce_param_or_forward_operand 42
+      %1 = transform.test_produce_self_handle_or_forward_operand
       // expected-remark @below {{succeeded}}
-      transform.test_consume_operand_if_matches_param_or_fail %1[42]
+      transform.test_consume_operand_of_op_kind_or_fail %1, "transform.test_produce_self_handle_or_forward_operand"
     }
   }
 }
@@ -315,17 +315,18 @@ transform.with_pdl_patterns {
       %3 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation
       // expected-remark @below {{applying}}
       transform.test_emit_remark_and_erase_operand %3, "applying" {fail_after_erase}
-      %4 = transform.test_produce_param_or_forward_operand 43
+      %4 = transform.test_produce_self_handle_or_forward_operand %3
       transform.yield %4 : !pdl.operation
     }, {
     ^bb2(%arg2: !pdl.operation):
-      %4 = transform.test_produce_param_or_forward_operand 42
+      %4 = transform.test_produce_self_handle_or_forward_operand
       transform.yield %4 : !pdl.operation
     }
     // The first alternative failed, so the returned value is taken from the
-    // second alternative.
+    // second alternative, associated test_produce_self_handle_or_forward_operand rather
+    // than pdl_match.
     // expected-remark @below {{succeeded}}
-    transform.test_consume_operand_if_matches_param_or_fail %2[42]
+    transform.test_consume_operand_of_op_kind_or_fail %2, "transform.test_produce_self_handle_or_forward_operand"
   }
 }
 
@@ -349,12 +350,12 @@ module {
     // expected-error @below {{scope must not contain the transforms being applied}}
     transform.alternatives %arg1 : !pdl.operation {
     ^bb2(%arg2: !pdl.operation):
-      %0 = transform.test_produce_param_or_forward_operand 42
-      transform.test_consume_operand_if_matches_param_or_fail %0[43]
+      %0 = transform.test_produce_self_handle_or_forward_operand
+      transform.test_consume_operand_of_op_kind_or_fail %0, "transform.sequence"
     }, {
     ^bb2(%arg2: !pdl.operation):
-      %0 = transform.test_produce_param_or_forward_operand 42
-      transform.test_consume_operand_if_matches_param_or_fail %0[42]
+      %0 = transform.test_produce_self_handle_or_forward_operand
+      transform.test_consume_operand_of_op_kind_or_fail %0, "transform.test_produce_self_handle_or_forward_operand"
     }
   }
 }
@@ -1094,6 +1095,14 @@ transform.sequence failures(propagate) {
 
 // -----
 
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{attempting to assign a null payload value to this transform handle}}
+  %0 = transform.test_produce_null_value : !transform.any_value
+}
+
+// -----
+
 // expected-error @below {{could not find a nested top-level transform op}}
 // expected-note @below {{use the 'transform-file-name' option to provide transform as external file}}
 module {
@@ -1106,7 +1115,65 @@ transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
 }
 
-// expected-error @below {{ore than one top-level transform op}}
+// expected-error @below {{more than one top-level transform op}}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+// expected-remark @below {{value handle}}
+// expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}}
+^bb1(%arg0: !transform.any_op):
+  %0 = test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value
+  test_print_remark_at_operand_value %0, "value handle" : !transform.any_value
+}
+
+// -----
+
+// expected-remark @below {{result handle}}
+// expected-note @below {{value handle points to an op result #1}}
+%0:2 = "test.get_two_results"() : () -> (i32, i32)
+// expected-remark @below {{result handle}}
+// expected-note @below {{value handle points to an op result #1}}
+%1:3 = "test.get_three_results"() : () -> (i32, i32, f32)
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %2 = transform.structured.match ops{["test.get_two_results", "test.get_three_results"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %3 = test_produce_value_handle_to_result %2, 1 : (!transform.any_op) -> !transform.any_value
+  test_print_remark_at_operand_value %3, "result handle" : !transform.any_value
+}
+
+// -----
+
+"test.op_with_regions"() ({
+^bb0:
+  "test.regon_terminator"() : () -> ()
+}, {
+^bb1:
+  "test.regon_terminator"() : () -> ()
+// expected-remark @below {{block argument handle}}
+// expected-note @below {{value handle points to a block argument #2 in block #1 in region #1}}
+^bb2(%arg0: i32, %arg1: f64, %arg3: index):
+  "test.match_anchor"() : () -> ()
+  "test.regon_terminator"() : () -> ()
+}) : () -> ()
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %2 = transform.structured.match ops{["test.match_anchor"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %3 = test_produce_value_handle_to_argument_of_parent_block %2, 2 : (!transform.any_op) -> !transform.any_value
+  test_print_remark_at_operand_value %3, "block argument handle" : !transform.any_value
+}
+
+// -----
+
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
+  // expected-note @below {{value defined here with type '!transform.test_dialect_param'}}
+  %0 = test_produce_param_with_number_of_test_ops %arg0 : !transform.any_op
+  // expected-error @below {{unexpectedly consumed a value that is not a handle as operand #0}}
+  test_consume_operand %0 : !transform.test_dialect_param
 }
index 1f29684..054ee07 100644 (file)
@@ -47,6 +47,18 @@ module {
 
 // -----
 
+// expected-error @below {{cannot replace an op with another op producing a different number of results while tracking handles}}
+module {
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !pdl.operation):
+    test_add_test_extension "A"
+    %dummy = test_remap_operand_to_self %arg0 : !transform.any_op
+  }
+}
+
+
+// -----
+
 module {
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !pdl.operation):
index 0bd3031..e32cb3e 100644 (file)
@@ -106,29 +106,73 @@ public:
 } // namespace
 
 DiagnosedSilenceableFailure
-mlir::test::TestProduceParamOrForwardOperandOp::apply(
+mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   if (getOperation()->getNumOperands() != 0) {
     results.set(getResult().cast<OpResult>(),
                 getOperation()->getOperand(0).getDefiningOp());
   } else {
-    results.set(getResult().cast<OpResult>(),
-                reinterpret_cast<Operation *>(*getParameter()));
+    results.set(getResult().cast<OpResult>(), getOperation());
   }
   return DiagnosedSilenceableFailure::success();
 }
 
-void mlir::test::TestProduceParamOrForwardOperandOp::getEffects(
+void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   if (getOperand())
     transform::onlyReadsHandle(getOperand(), effects);
   transform::producesHandle(getRes(), effects);
 }
 
-LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
-  if (getParameter().has_value() ^ (getNumOperands() != 1))
-    return emitOpError() << "expects either a parameter or an operand";
-  return success();
+DiagnosedSilenceableFailure
+mlir::test::TestProduceValueHandleToSelfOperand::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  results.setValues(getOut().cast<OpResult>(), getIn());
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestProduceValueHandleToSelfOperand::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getIn(), effects);
+  transform::producesHandle(getOut(), effects);
+  transform::onlyReadsPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestProduceValueHandleToResult::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  if (target->getNumResults() <= getNumber())
+    return emitSilenceableError() << "payload has no result #" << getNumber();
+  results.push_back(target->getResult(getNumber()));
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestProduceValueHandleToResult::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getIn(), effects);
+  transform::producesHandle(getOut(), effects);
+  transform::onlyReadsPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  if (!target->getBlock())
+    return emitSilenceableError() << "payload has no parent block";
+  if (target->getBlock()->getNumArguments() <= getNumber())
+    return emitSilenceableError()
+           << "parent of the payload has no argument #" << getNumber();
+  results.push_back(target->getBlock()->getArgument(getNumber()));
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getIn(), effects);
+  transform::producesHandle(getOut(), effects);
+  transform::onlyReadsPayload(effects);
 }
 
 DiagnosedSilenceableFailure
@@ -145,23 +189,21 @@ void mlir::test::TestConsumeOperand::getEffects(
   transform::modifiesPayload(effects);
 }
 
-DiagnosedSilenceableFailure
-mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
+DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
   assert(payload.size() == 1 && "expected a single target op");
-  auto value = reinterpret_cast<intptr_t>(payload[0]);
-  if (static_cast<uint64_t>(value) != getParameter()) {
+  if (payload[0]->getName().getStringRef() != getOpKind()) {
     return emitSilenceableError()
-           << "op expected the operand to be associated with " << getParameter()
-           << " got " << value;
+           << "op expected the operand to be associated a payload op of kind "
+           << getOpKind() << " got " << payload[0]->getName().getStringRef();
   }
 
   emitRemark() << "succeeded";
   return DiagnosedSilenceableFailure::success();
 }
 
-void mlir::test::TestConsumeOperandIfMatchesParamOrFail::getEffects(
+void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   transform::consumesHandle(getOperand(), effects);
   transform::modifiesPayload(effects);
@@ -182,6 +224,32 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
   transform::onlyReadsPayload(effects);
 }
 
+DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  ArrayRef<Value> values = state.getPayloadValues(getIn());
+  for (Value value : values) {
+    std::string note;
+    llvm::raw_string_ostream os(note);
+    if (auto arg = value.dyn_cast<BlockArgument>()) {
+      os << "a block argument #" << arg.getArgNumber() << " in block #"
+         << std::distance(arg.getOwner()->getParent()->begin(),
+                          arg.getOwner()->getIterator())
+         << " in region #" << arg.getOwner()->getParent()->getRegionNumber();
+    } else {
+      os << "an op result #" << value.cast<OpResult>().getResultNumber();
+    }
+    InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage();
+    diag.attachNote() << "value handle points to " << os.str();
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestPrintRemarkAtOperandValue::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getIn(), effects);
+  transform::onlyReadsPayload(effects);
+}
+
 DiagnosedSilenceableFailure
 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
                                           transform::TransformState &state) {
@@ -235,6 +303,7 @@ DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
 void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   transform::onlyReadsHandle(getOperand(), effects);
+  transform::producesHandle(getOut(), effects);
   transform::onlyReadsPayload(effects);
 }
 
@@ -528,6 +597,18 @@ mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestProduceNullValueOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::producesHandle(getOut(), effects);
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results,
+                                          transform::TransformState &state) {
+  results.setValues(getOut().cast<OpResult>(), Value());
+  return DiagnosedSilenceableFailure::success();
+}
+
 void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   if (getHasOperandEffect())
index cc67c2a..4c9b3d5 100644 (file)
@@ -39,37 +39,79 @@ def TestTransformTestDialectParamType
   let assemblyFormat = "";
 }
 
-def TestProduceParamOrForwardOperandOp
-  : Op<Transform_Dialect, "test_produce_param_or_forward_operand",
+def TestProduceSelfHandleOrForwardOperandOp
+  : Op<Transform_Dialect, "test_produce_self_handle_or_forward_operand",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let arguments = (ins
-     Optional<PDL_Operation>:$operand,
-     OptionalAttr<I64Attr>:$parameter);
+  let arguments = (ins Optional<PDL_Operation>:$operand);
   let results = (outs PDL_Operation:$res);
-  let assemblyFormat = "(`from` $operand^)? ($parameter^)? attr-dict";
+  let assemblyFormat = "($operand^)? attr-dict";
   let cppNamespace = "::mlir::test";
-  let hasVerifier = 1;
+}
+
+def TestProduceValueHandleToSelfOperand
+  : Op<Transform_Dialect, "test_produce_value_handle_to_self_operand",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins TransformHandleTypeInterface:$in);
+  let results = (outs TransformValueHandleTypeInterface:$out);
+  let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+  
+}
+
+def TestProduceValueHandleToResult
+  : Op<Transform_Dialect, "test_produce_value_handle_to_result",
+       [TransformEachOpTrait, TransformOpInterface,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins TransformHandleTypeInterface:$in,
+                       I64Attr:$number);
+  let results = (outs TransformValueHandleTypeInterface:$out);
+  let assemblyFormat = "$in `,` $number attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+    
+def TestProduceValueHandleToArgumentOfParentBlock
+  : Op<Transform_Dialect, "test_produce_value_handle_to_argument_of_parent_block",
+       [TransformEachOpTrait, TransformOpInterface,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins TransformHandleTypeInterface:$in,
+                       I64Attr:$number);
+  let results = (outs TransformValueHandleTypeInterface:$out);
+  let assemblyFormat = "$in `,` $number attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
 }
 
 def TestConsumeOperand : Op<Transform_Dialect, "test_consume_operand",
      [DeclareOpInterfaceMethods<TransformOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let arguments = (ins
-    PDL_Operation:$operand,
+    Transform_AnyHandleOrParamType:$operand,
     Optional<PDL_Operation>:$second_operand);
-  let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict";
+  let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict `:` type($operand)";
   let cppNamespace = "::mlir::test";
 }
 
-def TestConsumeOperandIfMatchesParamOrFail
-  : Op<Transform_Dialect, "test_consume_operand_if_matches_param_or_fail",
+def TestConsumeOperandOfOpKindOrFail
+  : Op<Transform_Dialect, "test_consume_operand_of_op_kind_or_fail",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let arguments = (ins
     PDL_Operation:$operand,
-    I64Attr:$parameter);
-  let assemblyFormat = "$operand `[` $parameter `]` attr-dict";
+    StrAttr:$op_kind);
+  let assemblyFormat = "$operand `,` $op_kind attr-dict";
   let cppNamespace = "::mlir::test";
 }
 
@@ -85,6 +127,16 @@ def TestPrintRemarkAtOperandOp
   let cppNamespace = "::mlir::test";
 }
 
+def TestPrintRemarkAtOperandValue
+  : Op<Transform_Dialect, "test_print_remark_at_operand_value",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins TransformValueHandleTypeInterface:$in,
+                       StrAttr:$message);
+  let assemblyFormat = "$in `,` $message attr-dict `:` type($in)";
+  let cppNamespace = "::mlir::test";
+}
+
 def TestAddTestExtensionOp
   : Op<Transform_Dialect, "test_add_test_extension",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
@@ -107,8 +159,9 @@ def TestRemapOperandPayloadToSelfOp
   : Op<Transform_Dialect, "test_remap_operand_to_self",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let arguments = (ins  PDL_Operation:$operand);
-  let assemblyFormat = "$operand attr-dict";
+  let arguments = (ins PDL_Operation:$operand);
+  let results = (outs Optional<TransformHandleTypeInterface>:$out);        
+  let assemblyFormat = "$operand attr-dict (`:` type($out)^)?";
   let cppNamespace = "::mlir::test";
 }
 
@@ -349,6 +402,15 @@ def TestProduceNullParamOp
   let cppNamespace = "::mlir::test";
 }
 
+def TestProduceNullValueOp
+  : Op<Transform_Dialect, "test_produce_null_value",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let results = (outs TransformValueHandleTypeInterface:$out);
+  let assemblyFormat = "attr-dict `:` type($out)";
+  let cppNamespace = "::mlir::test";
+}
+
 def TestRequiredMemoryEffectsOp
   : Op<Transform_Dialect, "test_required_memory_effects",
       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
index f9823e1..7beae91 100644 (file)
@@ -46,65 +46,93 @@ public:
     return "apply transform dialect operations one by one";
   }
 
-  ArrayRef<transform::MappedValue>
-  findOperationsByName(Operation *root, StringRef name,
-                       SmallVectorImpl<transform::MappedValue> &storage) {
-    size_t start = storage.size();
+  void findOperationsByName(Operation *root, StringRef name,
+                            SmallVectorImpl<Operation *> &operations) {
     root->walk([&](Operation *op) {
       if (op->getName().getStringRef() == name) {
-        storage.push_back(op);
+        operations.push_back(op);
       }
     });
-    return ArrayRef(storage).drop_front(start);
   }
 
-  ArrayRef<transform::MappedValue>
-  createParameterMapping(MLIRContext &context, ArrayRef<int> values,
-                         SmallVectorImpl<transform::MappedValue> &storage) {
-    size_t start = storage.size();
-    llvm::append_range(storage, llvm::map_range(values, [&](int v) {
-                         Builder b(&context);
-                         return transform::MappedValue(b.getI64IntegerAttr(v));
-                       }));
-    return ArrayRef(storage).drop_front(start);
+  void createParameterMapping(MLIRContext &context, ArrayRef<int> values,
+                              RaggedArray<transform::MappedValue> &result) {
+    SmallVector<transform::MappedValue> storage =
+        llvm::to_vector(llvm::map_range(values, [&](int v) {
+          Builder b(&context);
+          return transform::MappedValue(b.getI64IntegerAttr(v));
+        }));
+    result.push_back(std::move(storage));
+  }
+
+  void
+  createOpResultMapping(Operation *root, StringRef name,
+                        RaggedArray<transform::MappedValue> &extraMapping) {
+    SmallVector<Operation *> operations;
+    findOperationsByName(root, name, operations);
+    SmallVector<Value> results;
+    for (Operation *op : operations)
+      llvm::append_range(results, op->getResults());
+    extraMapping.push_back(results);
+  }
+
+  unsigned numberOfSetOptions(const Option<std::string> &ops,
+                              const ListOption<int> &params,
+                              const Option<std::string> &values) {
+    unsigned numSetValues = 0;
+    numSetValues += !ops.empty();
+    numSetValues += !params.empty();
+    numSetValues += !values.empty();
+    return numSetValues;
   }
 
   void runOnOperation() override {
-    if (!bindFirstExtraToOps.empty() && !bindFirstExtraToParams.empty()) {
-      emitError(UnknownLoc::get(&getContext()))
-          << "cannot bind the first extra top-level argument to both "
-             "operations and parameters";
+    unsigned firstSetOptions =
+        numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams,
+                           bindFirstExtraToResultsOfOps);
+    unsigned secondSetOptions =
+        numberOfSetOptions(bindSecondExtraToOps, bindSecondExtraToParams,
+                           bindSecondExtraToResultsOfOps);
+    auto loc = UnknownLoc::get(&getContext());
+    if (firstSetOptions > 1) {
+      emitError(loc) << "cannot bind the first extra top-level argument to "
+                        "multiple entities";
       return signalPassFailure();
     }
-    if (!bindSecondExtraToOps.empty() && !bindSecondExtraToParams.empty()) {
-      emitError(UnknownLoc::get(&getContext()))
-          << "cannot bind the second extra top-level argument to both "
-             "operations and parameters";
+    if (secondSetOptions > 1) {
+      emitError(loc) << "cannot bind the second extra top-level argument to "
+                        "multiple entities";
       return signalPassFailure();
     }
-    if ((!bindSecondExtraToOps.empty() || !bindSecondExtraToParams.empty()) &&
-        bindFirstExtraToOps.empty() && bindFirstExtraToParams.empty()) {
-      emitError(UnknownLoc::get(&getContext()))
-          << "cannot bind the second extra top-level argument without binding "
-             "the first";
-      return signalPassFailure();
+    if (firstSetOptions == 0 && secondSetOptions != 0) {
+      emitError(loc) << "cannot bind the second extra top-level argument "
+                        "without bindings the first";
     }
 
-    SmallVector<transform::MappedValue> extraMappingStorage;
-    SmallVector<ArrayRef<transform::MappedValue>> extraMapping;
+    RaggedArray<transform::MappedValue> extraMapping;
     if (!bindFirstExtraToOps.empty()) {
-      extraMapping.push_back(findOperationsByName(
-          getOperation(), bindFirstExtraToOps.getValue(), extraMappingStorage));
+      SmallVector<Operation *> operations;
+      findOperationsByName(getOperation(), bindFirstExtraToOps.getValue(),
+                           operations);
+      extraMapping.push_back(operations);
     } else if (!bindFirstExtraToParams.empty()) {
-      extraMapping.push_back(createParameterMapping(
-          getContext(), bindFirstExtraToParams, extraMappingStorage));
+      createParameterMapping(getContext(), bindFirstExtraToParams,
+                             extraMapping);
+    } else if (!bindFirstExtraToResultsOfOps.empty()) {
+      createOpResultMapping(getOperation(), bindFirstExtraToResultsOfOps,
+                            extraMapping);
     }
+
     if (!bindSecondExtraToOps.empty()) {
-      extraMapping.push_back(findOperationsByName(
-          getOperation(), bindSecondExtraToOps, extraMappingStorage));
+      SmallVector<Operation *> operations;
+      findOperationsByName(getOperation(), bindSecondExtraToOps, operations);
+      extraMapping.push_back(operations);
     } else if (!bindSecondExtraToParams.empty()) {
-      extraMapping.push_back(createParameterMapping(
-          getContext(), bindSecondExtraToParams, extraMappingStorage));
+      createParameterMapping(getContext(), bindSecondExtraToParams,
+                             extraMapping);
+    } else if (!bindSecondExtraToResultsOfOps.empty()) {
+      createOpResultMapping(getOperation(), bindSecondExtraToResultsOfOps,
+                            extraMapping);
     }
 
     options = options.enableExpensiveChecks(enableExpensiveChecks);
@@ -128,6 +156,10 @@ public:
       *this, "bind-first-extra-to-params",
       llvm::cl::desc("bind the first extra argument of the top-level op to "
                      "the given integer parameters")};
+  Option<std::string> bindFirstExtraToResultsOfOps{
+      *this, "bind-first-extra-to-results-of-ops",
+      llvm::cl::desc("bind the first extra argument of the top-level op to "
+                     "results of payload operations of the given kind")};
 
   Option<std::string> bindSecondExtraToOps{
       *this, "bind-second-extra-to-ops",
@@ -137,6 +169,11 @@ public:
       *this, "bind-second-extra-to-params",
       llvm::cl::desc("bind the second extra argument of the top-level op to "
                      "the given integer parameters")};
+  Option<std::string> bindSecondExtraToResultsOfOps{
+      *this, "bind-second-extra-to-results-of-ops",
+      llvm::cl::desc("bind the second extra argument of the top-level op to "
+                     "results of payload operations of the given kind")};
+
   Option<std::string> transformFileName{
       *this, "transform-file-name", llvm::cl::init(""),
       llvm::cl::desc(