[mlir] more side effect verification in transform dialect
authorAlex Zinenko <zinenko@google.com>
Mon, 23 Jan 2023 14:46:46 +0000 (14:46 +0000)
committerAlex Zinenko <zinenko@google.com>
Mon, 6 Feb 2023 13:15:36 +0000 (13:15 +0000)
Add a verifier checking that if a transform operation consumes a handle
(which is associated with a payload operation being erased or
recreated), it also indicates modification of the payload IR. This
hasn't been consistent in the past because of the "no-aliasing"
assumption where we couldn't have had more than one handle to an
operation, requiring some handle-manipulation operations, such as
`transform.merge_handles` to consume their operands. That assumption has
been liften and it is no longer necessary for these operations to
consume handles and thus make the life harder for the clients.

Additionally, remove TransformEffects.td that uses the ODS mechanism for
indicating side effects that works only for operands and results. It
was being used incorrectly to also indicate effects on the payload IR,
not assocaited with any IR value, and lacked the consume/produce
semantics available via helpers in C++.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D142361

16 files changed:
mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td [deleted file]
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

index 1ab7bd3..76f0d8e 100644 (file)
@@ -11,7 +11,6 @@
 
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
index 2431488..3a66391 100644 (file)
@@ -11,7 +11,6 @@
 
 include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
index 802a915..2ee1354 100644 (file)
@@ -10,7 +10,6 @@
 #define GPU_TRANSFORM_OPS
 
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
index 630000c..5e776e0 100644 (file)
@@ -10,7 +10,6 @@
 #define LINALG_TRANSFORM_OPS
 
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
@@ -89,7 +88,8 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
 
 def FuseIntoContainingOp :
     Op<Transform_Dialect, "structured.fuse_into_containing_op",
-      [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+      [DeclareOpInterfaceMethods<TransformOpInterface>,
+       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Fuse a producer into a containing operation.";
 
   let description = [{
@@ -125,14 +125,9 @@ def FuseIntoContainingOp :
     This operation reads the containing op handle.
   }];
 
-  let arguments = (ins Arg<PDL_Operation, "",
-                           [TransformMappingRead,
-                            TransformMappingFree]>:$producer_op,
-                       Arg<PDL_Operation, "",
-                           [TransformMappingRead]>:$containing_op);
-  let results = (outs Res<PDL_Operation, "",
-                          [TransformMappingAlloc,
-                           TransformMappingWrite]>:$fused_op);
+  let arguments = (ins PDL_Operation:$producer_op,
+                       PDL_Operation:$containing_op);
+  let results = (outs PDL_Operation:$fused_op);
   let assemblyFormat = "$producer_op `into` $containing_op attr-dict";
 
   let builders = [
index c480e7c..f16fe8a 100644 (file)
@@ -10,7 +10,6 @@
 #define MEMREF_TRANSFORM_OPS
 
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
index affa9ab..b286850 100644 (file)
@@ -10,7 +10,6 @@
 #define SCF_TRANSFORM_OPS
 
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td b/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td
deleted file mode 100644 (file)
index b6106fe..0000000
+++ /dev/null
@@ -1,62 +0,0 @@
-
-//===- TransformEffect.td - Transform side effects ---------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines side effects and associated resources for operations in the
-// Transform dialect and extensions.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD
-#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD
-
-include "mlir/Interfaces/SideEffectInterfaces.td"
-
-//===----------------------------------------------------------------------===//
-// Effects on the mapping between Transform IR values and Payload IR ops.
-//===----------------------------------------------------------------------===//
-
-// Side effect resource corresponding to the mapping between transform IR values
-// and Payload IR operations.
-def TransformMappingResource
-    : Resource<"::mlir::transform::TransformMappingResource">;
-
-// Describes the creation of a new entry in the transform mapping. Should be
-// accompanied by the Write effect as the entry is immediately initialized by
-// any reasonable transform operation.
-def TransformMappingAlloc : MemAlloc<TransformMappingResource>;
-
-// Describes the removal of an entry in the transform mapping. Typically
-// accompanied by the Read effect.
-def TransformMappingFree : MemFree<TransformMappingResource>;
-
-// Describes the access to the mapping. Read-only accesses can be reordered.
-def TransformMappingRead : MemRead<TransformMappingResource>;
-
-// Describes a modification of an existing entry in the mapping. It is rarely
-// used alone, and is mostly accompanied by the Allocate effect.
-def TransformMappingWrite : MemWrite<TransformMappingResource>;
-
-//===----------------------------------------------------------------------===//
-// Effects on Payload IR.
-//===----------------------------------------------------------------------===//
-
-// Side effect resource corresponding to the Payload IR itself.
-def PayloadIRResource : Resource<"::mlir::transform::PayloadIRResource">;
-
-// Corresponds to the read-only access to the Payload IR through some operation
-// handles in the Transform IR.
-def PayloadIRRead : MemRead<PayloadIRResource>;
-
-// Corresponds to the mutation of the Payload IR through an operation handle in
-// the Transform IR. Should be accompanied by the Read effect for most transform
-// operations (only a complete overwrite of the root op of the Payload IR is a
-// write-only modification).
-def PayloadIRWrite : MemWrite<PayloadIRResource>;
-
-#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD
index 6f3b4cf..dd66e61 100644 (file)
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 
 def AlternativesOp : TransformDialectOp<"alternatives",
@@ -466,7 +466,8 @@ def SequenceOp : TransformDialectOp<"sequence",
 
 def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
     [DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
-     OpAsmOpInterface, PossibleTopLevelTransformOpTrait, RecursiveMemoryEffects,
+     OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      SymbolTable]> {
   let summary = "Contains PDL patterns available for use in transforms";
   let description = [{
@@ -505,8 +506,8 @@ def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
   }];
 
   let arguments = (ins
-    Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR",
-        [TransformMappingRead]>:$root);
+    Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR"
+        >:$root);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions";
 
@@ -518,7 +519,8 @@ def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
   }];
 }
 
-def YieldOp : TransformDialectOp<"yield", [Terminator]> {
+def YieldOp : TransformDialectOp<"yield",
+    [Terminator, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Yields operation handles from a transform IR region";
   let description = [{
     This terminator operation yields operation handles from regions of the
@@ -527,8 +529,8 @@ def YieldOp : TransformDialectOp<"yield", [Terminator]> {
   }];
 
   let arguments = (ins
-    Arg<Variadic<TransformHandleTypeInterface>, "Operation handles yielded back to the parent",
-        [TransformMappingRead]>:$operands);
+    Arg<Variadic<TransformHandleTypeInterface>, "Operation handles yielded back to the parent"
+        >:$operands);
   let assemblyFormat = "operands attr-dict (`:` type($operands)^)?";
 
   let builders = [
index 060e6bc..4533c5a 100644 (file)
@@ -10,7 +10,6 @@
 #define VECTOR_TRANSFORM_OPS
 
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
index a1b2d48..a2f8a1d 100644 (file)
@@ -60,13 +60,14 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
 
 void transform::OneShotBufferizeOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
-                       TransformMappingResource::get());
-
   // Handles that are not modules are not longer usable.
-  if (!getTargetIsModule())
-    effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
-                         TransformMappingResource::get());
+  if (!getTargetIsModule()) {
+    consumesHandle(getTarget(), effects);
+  } else {
+    onlyReadsHandle(getTarget(), effects);
+  }
+
+  modifiesPayload(effects);
 }
 
 //===----------------------------------------------------------------------===//
index 022c94e..94725fa 100644 (file)
@@ -713,6 +713,14 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+void transform::FuseIntoContainingOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getProducerOp(), effects);
+  onlyReadsHandle(getContainingOp(), effects);
+  producesHandle(getFusedOp(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // GeneralizeOp
 //===----------------------------------------------------------------------===//
@@ -2668,6 +2676,7 @@ void transform::TileToForeachThreadOp::getEffects(
   onlyReadsHandle(getPackedNumThreads(), effects);
   onlyReadsHandle(getPackedTileSizes(), effects);
   producesHandle(getResults(), effects);
+  modifiesPayload(effects);
 }
 
 SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
@@ -2997,6 +3006,7 @@ void transform::MaskedVectorizeOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTarget(), effects);
   onlyReadsHandle(getVectorSizes(), effects);
+  modifiesPayload(effects);
 }
 
 SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() {
index 5ecc1f4..e14fca2 100644 (file)
@@ -783,6 +783,7 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
         });
   };
 
+  std::optional<unsigned> firstConsumedOperand = std::nullopt;
   for (OpOperand &operand : op->getOpOperands()) {
     auto range = effectsOn(operand.get());
     if (range.empty()) {
@@ -793,7 +794,30 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
                         << operand.getOperandNumber();
       return diag;
     }
+    if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
+      InFlightDiagnostic diag = op->emitError()
+                                << "TransformOpInterface did not expect "
+                                   "'allocate' memory effect on an operand";
+      diag.attachNote() << "specified for operand #"
+                        << operand.getOperandNumber();
+      return diag;
+    }
+    if (!firstConsumedOperand &&
+        ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
+      firstConsumedOperand = operand.getOperandNumber();
+    }
+  }
+
+  if (firstConsumedOperand &&
+      !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
+    InFlightDiagnostic diag =
+        op->emitError()
+        << "TransformOpInterface expects ops consuming operands to have a "
+           "'write' effect on the payload resource";
+    diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
+    return diag;
   }
+
   for (OpResult result : op->getResults()) {
     auto range = effectsOn(result);
     if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
@@ -806,6 +830,7 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
       return diag;
     }
   }
+
   return success();
 }
 
index f9bdc0e..bc1ac6d 100644 (file)
@@ -292,7 +292,7 @@ transform::CastOp::applyToOne(Operation *target, ApplyToEachResultList &results,
 void transform::CastOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   onlyReadsPayload(effects);
-  consumesHandle(getInput(), effects);
+  onlyReadsHandle(getInput(), effects);
   producesHandle(getOutput(), effects);
 }
 
@@ -501,7 +501,7 @@ bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
 
 void transform::MergeHandlesOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  consumesHandle(getHandles(), effects);
+  onlyReadsHandle(getHandles(), effects);
   producesHandle(getResult(), effects);
 
   // There are no effects on the Payload IR as this is only a handle
@@ -557,7 +557,7 @@ transform::SplitHandlesOp::apply(transform::TransformResults &results,
 
 void transform::SplitHandlesOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  consumesHandle(getHandle(), effects);
+  onlyReadsHandle(getHandle(), effects);
   producesHandle(getResults(), effects);
   // There are no effects on the Payload IR as this is only a handle
   // manipulation.
@@ -626,7 +626,7 @@ transform::ReplicateOp::apply(transform::TransformResults &results,
 void transform::ReplicateOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   onlyReadsHandle(getPattern(), effects);
-  consumesHandle(getHandles(), effects);
+  onlyReadsHandle(getHandles(), effects);
   producesHandle(getReplicated(), effects);
 }
 
@@ -832,34 +832,62 @@ remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
     effects.emplace_back(effect.getEffect(), target, effect.getResource());
 }
 
-void transform::SequenceOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  onlyReadsHandle(getRoot(), effects);
-  onlyReadsHandle(getExtraBindings(), effects);
-  producesHandle(getResults(), effects);
+namespace {
+template <typename T>
+using has_get_extra_bindings = decltype(std::declval<T &>().getExtraBindings());
+} // namespace
+
+/// Populate `effects` with transform dialect memory effects for the potential
+/// top-level operation. Such operations have recursive effects from nested
+/// operations. When they have an operand, we can additionally remap effects on
+/// the block argument to be effects on the operand.
+template <typename OpTy>
+static void getPotentialTopLevelEffects(
+    OpTy operation, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(operation->getOperands(), effects);
+  transform::producesHandle(operation->getResults(), effects);
+
+  if (!operation.getRoot()) {
+    for (Operation &op : *operation.getBodyBlock()) {
+      auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
+      if (!iface)
+        continue;
 
-  if (!getRoot()) {
-    for (Operation &op : *getBodyBlock()) {
-      auto iface = cast<MemoryEffectOpInterface>(&op);
       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
       iface.getEffects(effects);
     }
     return;
   }
 
-  // Carry over all effects on the argument of the entry block as those on the
-  // operand, this is the same value just remapped.
-  for (Operation &op : *getBodyBlock()) {
-    auto iface = cast<MemoryEffectOpInterface>(&op);
+  // Carry over all effects on arguments of the entry block as those on the
+  // operands, this is the same value just remapped.
+  for (Operation &op : *operation.getBodyBlock()) {
+    auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
+    if (!iface)
+      continue;
 
-    remapEffects(iface, getBodyBlock()->getArgument(0), getRoot(), effects);
-    for (auto [source, target] : llvm::zip(
-             getBodyBlock()->getArguments().drop_front(), getExtraBindings())) {
-      remapEffects(iface, source, target, effects);
+    remapEffects(iface, operation.getBodyBlock()->getArgument(0),
+                 operation.getRoot(), effects);
+    if constexpr (llvm::is_detected<has_get_extra_bindings, OpTy>::value) {
+      for (auto [source, target] :
+           llvm::zip(operation.getBodyBlock()->getArguments().drop_front(),
+                     operation.getExtraBindings())) {
+        remapEffects(iface, source, target, effects);
+      }
     }
+
+    SmallVector<MemoryEffects::EffectInstance> nestedEffects;
+    iface.getEffectsOnResource(transform::PayloadIRResource::get(),
+                               nestedEffects);
+    llvm::append_range(effects, nestedEffects);
   }
 }
 
+void transform::SequenceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  getPotentialTopLevelEffects(*this, effects);
+}
+
 OperandRange transform::SequenceOp::getSuccessorEntryOperands(
     std::optional<unsigned> index) {
   assert(index && *index == 0 && "unexpected region index");
@@ -983,6 +1011,11 @@ transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
   return state.applyTransform(transformOp);
 }
 
+void transform::WithPDLPatternsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  getPotentialTopLevelEffects(*this, effects);
+}
+
 LogicalResult transform::WithPDLPatternsOp::verify() {
   Block *body = getBodyBlock();
   Operation *topLevelOp = nullptr;
@@ -1065,3 +1098,12 @@ void transform::PrintOp::getEffects(
   // writes into the default resource.
   effects.emplace_back(MemoryEffects::Write::get());
 }
+
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+void transform::YieldOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getOperands(), effects);
+}
index 2fd0a37..500fe61 100644 (file)
@@ -251,7 +251,7 @@ transform.sequence failures(suppress) {
 ^bb0(%arg0: !transform.any_op):
   // expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}}
   // expected-note @below {{no effects specified for operand #0}}
-  transform.test_required_memory_effects %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.test_required_memory_effects %arg0 {modifies_payload} : (!transform.any_op) -> !transform.any_op
 }
 
 // -----
@@ -260,5 +260,5 @@ transform.sequence failures(suppress) {
 ^bb0(%arg0: !transform.any_op):
   // expected-error @below {{TransformOpInterface requires 'allocate' memory effect to be specified for results}}
   // expected-note @below {{no 'allocate' effect specified for result #0}}
-  transform.test_required_memory_effects %arg0 {has_operand_effect} : (!transform.any_op) -> !transform.any_op
+  transform.test_required_memory_effects %arg0 {has_operand_effect, modifies_payload} : (!transform.any_op) -> !transform.any_op
 }
index 63d6828..0bd3031 100644 (file)
@@ -118,6 +118,13 @@ mlir::test::TestProduceParamOrForwardOperandOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestProduceParamOrForwardOperandOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  if (getOperand())
+    transform::onlyReadsHandle(getOperand(), effects);
+  transform::producesHandle(getRes(), effects);
+}
+
 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
   if (getParameter().has_value() ^ (getNumOperands() != 1))
     return emitOpError() << "expects either a parameter or an operand";
@@ -130,6 +137,14 @@ mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestConsumeOperand::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getOperand(), effects);
+  if (getSecondOperand())
+    transform::consumesHandle(getSecondOperand(), effects);
+  transform::modifiesPayload(effects);
+}
+
 DiagnosedSilenceableFailure
 mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
     transform::TransformResults &results, transform::TransformState &state) {
@@ -146,6 +161,12 @@ mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestConsumeOperandIfMatchesParamOrFail::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getOperand(), effects);
+  transform::modifiesPayload(effects);
+}
+
 DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
@@ -155,6 +176,12 @@ DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getOperand(), effects);
+  transform::onlyReadsPayload(effects);
+}
+
 DiagnosedSilenceableFailure
 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
                                           transform::TransformState &state) {
@@ -187,6 +214,12 @@ mlir::test::TestCheckIfTestExtensionPresentOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getOperand(), effects);
+  transform::onlyReadsPayload(effects);
+}
+
 DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   auto *extension = state.getExtension<TestTransformStateExtension>();
@@ -199,6 +232,12 @@ DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getOperand(), effects);
+  transform::onlyReadsPayload(effects);
+}
+
 DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   state.removeExtension<TestTransformStateExtension>();
@@ -312,6 +351,13 @@ mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestCopyPayloadOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getHandle(), effects);
+  transform::producesHandle(getCopy(), effects);
+  transform::onlyReadsPayload(effects);
+}
+
 DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
     Location loc, ArrayRef<Operation *> payload) const {
   if (payload.empty())
@@ -491,6 +537,9 @@ void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
     transform::producesHandle(getOut(), effects);
   else
     transform::onlyReadsHandle(getOut(), effects);
+
+  if (getModifiesPayload())
+    transform::modifiesPayload(effects);
 }
 
 DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
index 02e8a69..cc67c2a 100644 (file)
 #ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
 #define MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
 
+include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 
@@ -41,35 +41,33 @@ def TestTransformTestDialectParamType
 
 def TestProduceParamOrForwardOperandOp
   : Op<Transform_Dialect, "test_produce_param_or_forward_operand",
-       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let arguments = (ins
-     Arg<Optional<PDL_Operation>, "", [TransformMappingRead]>:$operand,
+     Optional<PDL_Operation>:$operand,
      OptionalAttr<I64Attr>:$parameter);
-  let results = (outs
-     Res<PDL_Operation, "",
-         [TransformMappingAlloc, TransformMappingWrite]>:$res);
+  let results = (outs PDL_Operation:$res);
   let assemblyFormat = "(`from` $operand^)? ($parameter^)? attr-dict";
   let cppNamespace = "::mlir::test";
   let hasVerifier = 1;
 }
 
 def TestConsumeOperand : Op<Transform_Dialect, "test_consume_operand",
-     [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+     [DeclareOpInterfaceMethods<TransformOpInterface>,
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let arguments = (ins
-    Arg<PDL_Operation, "",
-        [TransformMappingRead, TransformMappingFree]>:$operand,
-    Arg<Optional<PDL_Operation>, "",
-        [TransformMappingRead, TransformMappingFree]>:$second_operand);
+    PDL_Operation:$operand,
+    Optional<PDL_Operation>:$second_operand);
   let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict";
   let cppNamespace = "::mlir::test";
 }
 
 def TestConsumeOperandIfMatchesParamOrFail
   : Op<Transform_Dialect, "test_consume_operand_if_matches_param_or_fail",
-       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let arguments = (ins
-    Arg<PDL_Operation, "",
-        [TransformMappingRead, TransformMappingFree]>:$operand,
+    PDL_Operation:$operand,
     I64Attr:$parameter);
   let assemblyFormat = "$operand `[` $parameter `]` attr-dict";
   let cppNamespace = "::mlir::test";
@@ -77,10 +75,10 @@ def TestConsumeOperandIfMatchesParamOrFail
 
 def TestPrintRemarkAtOperandOp
   : Op<Transform_Dialect, "test_print_remark_at_operand",
-       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let arguments = (ins
-    Arg<TransformHandleTypeInterface, "",
-        [TransformMappingRead, PayloadIRRead]>:$operand,
+    TransformHandleTypeInterface:$operand,
     StrAttr:$message);
   let assemblyFormat =
     "$operand `,` $message attr-dict `:` type($operand)";
@@ -98,19 +96,18 @@ def TestAddTestExtensionOp
 
 def TestCheckIfTestExtensionPresentOp
   : Op<Transform_Dialect, "test_check_if_test_extension_present",
-       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let arguments = (ins
-    Arg<PDL_Operation, "", [TransformMappingRead, PayloadIRRead]>:$operand);
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins PDL_Operation:$operand);
   let assemblyFormat = "$operand attr-dict";
   let cppNamespace = "::mlir::test";
 }
 
 def TestRemapOperandPayloadToSelfOp
   : Op<Transform_Dialect, "test_remap_operand_to_self",
-       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let arguments = (ins
-    Arg<PDL_Operation, "",
-        [TransformMappingRead, TransformMappingWrite, PayloadIRRead]>:$operand);
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins  PDL_Operation:$operand);
   let assemblyFormat = "$operand attr-dict";
   let cppNamespace = "::mlir::test";
 }
@@ -255,10 +252,10 @@ def TestPrintNumberOfAssociatedPayloadIROps
 
 def TestCopyPayloadOp
   : Op<Transform_Dialect, "test_copy_payload",
-       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let arguments = (ins Arg<PDL_Operation, "", [TransformMappingRead]>:$handle);
-  let results = (outs Res<PDL_Operation, "",
-      [TransformMappingAlloc, TransformMappingWrite]>:$copy);
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins PDL_Operation:$handle);
+  let results = (outs PDL_Operation:$copy);
   let cppNamespace = "::mlir::test";
   let assemblyFormat = "$handle attr-dict";
 }
@@ -358,7 +355,8 @@ def TestRequiredMemoryEffectsOp
        DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let arguments = (ins TransformHandleTypeInterface:$in,
                        UnitAttr:$has_operand_effect,
-                       UnitAttr:$has_result_effect);
+                       UnitAttr:$has_result_effect,
+                       UnitAttr:$modifies_payload);
   let results = (outs TransformHandleTypeInterface:$out);
   let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
   let cppNamespace = "::mlir::test";