[mlir] add readonly/consume annotations to transform named sequences
authorAlex Zinenko <zinenko@google.com>
Thu, 30 Mar 2023 12:31:48 +0000 (12:31 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 4 Apr 2023 09:38:00 +0000 (09:38 +0000)
Use the argument attribute mechanism for function-like operations to
annotate the arguments of named transform sequences as consuming or only
reading the handles passed as arguments. This makes it possible to
correctly specify handle invalidation for external named sequences by
requiring their declarations to always provide such annotations.
Additionally, these annotations remove the need to analyze the body of
a named sequence to understand its effects on the arguments. Make them
required for named sequences that are called from the same file, in
addition to external sequences.

Provide a convenience pass that infers annotations by analyzing bodies
of named sequences provided they are not called from the same file.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D147223

17 files changed:
mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/Transforms/Passes.h
mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp [new file with mode: 0644]
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
mlir/test/Dialect/Transform/infer-effects.mlir [new file with mode: 0644]
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
mlir/test/Dialect/Transform/test-interpreter.mlir

index 639a7c7..f034f3a 100644 (file)
@@ -36,6 +36,13 @@ def Transform_Dialect : Dialect {
       constexpr const static llvm::StringLiteral
           kTargetTagAttrName = "transform.target_tag";
 
+      /// Names of the attributes indicating whether an argument of an external
+      /// transform dialect symbol is consumed or only read.
+      constexpr const static llvm::StringLiteral
+          kArgConsumedAttrName = "transform.consumed";
+      constexpr const static llvm::StringLiteral
+          kArgReadOnlyAttrName = "transform.readonly";
+
       /// Returns the named PDL constraint functions available in the dialect
       /// as a map from their name to the function.
       const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
@@ -114,7 +121,7 @@ def Transform_Dialect : Dialect {
   }];
 }
 
-// Base class for ops that belong to the tranfsorm dialect. Ops defined in
+// Base class for ops that belong to the transform dialect. Ops defined in
 // extensions of this dialect may also use this.
 class TransformDialectOp<string mnemonic, list<Trait> traits = []>
     : Op<Transform_Dialect, mnemonic, traits>;
index 41b0840..9f4e3d8 100644 (file)
@@ -847,6 +847,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
 void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
 void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
 
+/// Populates `consumedArguments` with positions of `block` arguments that are
+/// consumed by the operations in the `block`.
+void getConsumedBlockArguments(
+    Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
+
 /// Trait implementing the MemoryEffectOpInterface for operations that "consume"
 /// their operands and produce new results.
 template <typename OpTy>
index f19570f..7a7dfe4 100644 (file)
@@ -1,4 +1,4 @@
-//===- CheckUses.h - Expensive transform value validity checks --*- C++ -*-===//
+//===- Passes.h - Transform dialect pass entry points -----------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
index 4fdd2e3..2400066 100644 (file)
@@ -32,4 +32,14 @@ def CheckUsesPass : Pass<"transform-dialect-check-uses"> {
   }];
 }
 
+def InferEffectsPass : Pass<"transform-infer-effects"> {
+  let summary = "infer transform side effects for symbols";
+  let description = [{
+    This pass analyzes the definitions of transform dialect callable symbol
+    operations, such as `transform.named_sequence`, and annotates the symbol
+    arguments with attributes indicating the side effects that the nested
+    operations have on them.
+  }];
+}
+
 #endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
index 99ff80e..d4578e0 100644 (file)
@@ -175,6 +175,14 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
     }
     return success();
   }
+  if (attribute.getName().getValue() == kArgConsumedAttrName ||
+      attribute.getName().getValue() == kArgReadOnlyAttrName) {
+    if (!attribute.getValue().isa<UnitAttr>()) {
+      return op->emitError()
+             << attribute.getName() << " must be a unit attribute";
+    }
+    return success();
+  }
   return emitError(op->getLoc())
          << "unknown attribute: " << attribute.getName();
 }
index c9d28f9..c4e868e 100644 (file)
@@ -1318,6 +1318,29 @@ void transform::onlyReadsPayload(
   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
 }
 
+void transform::getConsumedBlockArguments(
+    Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  for (Operation &nested : block) {
+    auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
+    if (!iface)
+      continue;
+
+    effects.clear();
+    iface.getEffects(effects);
+    for (const MemoryEffects::EffectInstance &effect : effects) {
+      BlockArgument argument =
+          dyn_cast_or_null<BlockArgument>(effect.getValue());
+      if (!argument || argument.getOwner() != &block ||
+          !isa<MemoryEffects::Free>(effect.getEffect()) ||
+          effect.getResource() != transform::TransformMappingResource::get()) {
+        continue;
+      }
+      consumedArguments.insert(argument.getArgNumber());
+    }
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Utilities for TransformOpInterface.
 //===----------------------------------------------------------------------===//
index a37822d..c4f5769 100644 (file)
@@ -720,32 +720,44 @@ verifyNamedSequenceOp(transform::NamedSequenceOp op);
 
 void transform::IncludeOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Always mark as modifying the payload.
+  // TODO: a mechanism to annotate effects on payload. Even when all handles are
+  // only read, the payload may still be modified, so we currently stay on the
+  // conservative side and always indicate modification. This may prevent some
+  // code reordering.
+  modifiesPayload(effects);
+
+  // Results are always produced.
+  producesHandle(getResults(), effects);
+
+  // Adds default effects to operands and results. This will be added if
+  // preconditions fail so the trait verifier doesn't complain about missing
+  // effects and the real precondition failure is reported later on.
+  auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); };
+
   // Bail if the callee is unknown. This may run as part of the verification
   // process before we verified the validity of the callee or of this op.
   auto target =
       getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
   if (!target)
-    return;
+    return defaultEffects();
   auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
       getOperation(), getTarget());
   if (!callee)
-    return;
+    return defaultEffects();
   DiagnosedSilenceableFailure earlyVerifierResult =
       verifyNamedSequenceOp(callee);
   if (!earlyVerifierResult.succeeded()) {
     (void)earlyVerifierResult.silence();
-    return;
+    return defaultEffects();
   }
 
-  // Carry over effects from the callee.
-  // TODO: external callees must provides attributes annotating the
-  // readonly/consume effects on operands.
-  if (!callee.isExternal())
-    remapArgumentEffects(callee.getBody().front(), getOperands(), effects);
-
-  // Proper effects.
-  onlyReadsHandle(getOperands(), effects);
-  producesHandle(getResults(), effects);
+  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
+    if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
+      consumesHandle(getOperand(i), effects);
+    else
+      onlyReadsHandle(getOperand(i), effects);
+  }
 }
 
 template <typename... Tys>
@@ -753,6 +765,52 @@ static bool implementSameInterface(Type t1, Type t2) {
   return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
 }
 
+/// Checks that the attributes of the named sequence operation have correct
+/// consumption effect annotations. If `alsoVerifyInternal`, checks for
+/// annotations being present even if they can be inferred from the body.
+static DiagnosedSilenceableFailure
+verifyNamedSequenceConsumeAnnotations(transform::NamedSequenceOp op,
+                                      bool alsoVerifyInternal = false) {
+  llvm::SmallDenseSet<unsigned> consumedArguments;
+  if (!op.isExternal()) {
+    transform::getConsumedBlockArguments(op.getBody().front(),
+                                         consumedArguments);
+  }
+  for (unsigned i = 0, e = op.getFunctionType().getNumInputs(); i < e; ++i) {
+    bool isConsumed =
+        op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
+        nullptr;
+    bool isReadOnly =
+        op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
+        nullptr;
+    if (isConsumed && isReadOnly) {
+      return op.emitSilenceableError()
+             << "argument #" << i << " cannot be both readonly and consumed";
+    }
+    if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
+      return op.emitSilenceableError()
+             << "must provide consumed/readonly status for arguments of "
+                "external or called ops";
+    }
+    if (op.isExternal())
+      continue;
+
+    if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
+      return op.emitSilenceableError()
+             << "argument #" << i
+             << " is consumed in the body but is not marked as such";
+    }
+    if (!consumedArguments.contains(i) && isConsumed) {
+      Diagnostic warning(op->getLoc(), DiagnosticSeverity::Warning);
+      warning << "argument #" << i
+              << " is not consumed in the body but is marked as consumed";
+      return DiagnosedSilenceableFailure::silenceableFailure(
+          std::move(warning));
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 LogicalResult
 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   // Access through indirection and do additional checking because this may be
@@ -794,7 +852,9 @@ transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     }
   }
 
-  return success();
+  return verifyNamedSequenceConsumeAnnotations(target,
+                                               /*alsoVerifyInternal=*/true)
+      .checkAndReport();
 }
 
 //===----------------------------------------------------------------------===//
@@ -899,7 +959,7 @@ verifyNamedSequenceOp(transform::NamedSequenceOp op) {
   }
 
   if (op.isExternal() || op.getBody().empty())
-    return DiagnosedSilenceableFailure::success();
+    return verifyNamedSequenceConsumeAnnotations(op);
 
   if (op.getBody().front().empty())
     return emitSilenceableFailure(op) << "expected a non-empty body block";
@@ -931,7 +991,7 @@ verifyNamedSequenceOp(transform::NamedSequenceOp op) {
            << operandType << " vs " << resultType << ")";
   }
 
-  return DiagnosedSilenceableFailure::success();
+  return verifyNamedSequenceConsumeAnnotations(op);
 }
 
 LogicalResult transform::NamedSequenceOp::verify() {
index bf9a255..68b363d 100644 (file)
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRTransformDialectTransforms
   CheckUses.cpp
+  InferEffects.cpp
   TransformInterpreterPassBase.cpp
 
   DEPENDS
diff --git a/mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp b/mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp
new file mode 100644 (file)
index 0000000..461ae9b
--- /dev/null
@@ -0,0 +1,69 @@
+//===- InferEffects.cpp - Infer memory effects for named symbols ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/DenseSet.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace transform {
+#define GEN_PASS_DEF_INFEREFFECTSPASS
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+} // namespace transform
+} // namespace mlir
+
+static LogicalResult inferSideEffectAnnotations(Operation *op) {
+  if (!isa<transform::TransformOpInterface>(op))
+    return success();
+
+  auto func = dyn_cast<FunctionOpInterface>(op);
+  if (!func || func.isExternal())
+    return success();
+
+  if (!func.getFunctionBody().hasOneBlock()) {
+    return op->emitError()
+           << "only single-block operations are currently supported";
+  }
+
+  // Note that there can't be an inclusion of an unannotated symbol because it
+  // wouldn't have passed the verifier, so recursion isn't necessary here.
+  llvm::SmallDenseSet<unsigned> consumedArguments;
+  transform::getConsumedBlockArguments(func.getFunctionBody().front(),
+                                       consumedArguments);
+
+  for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
+    func.setArgAttr(i,
+                    consumedArguments.contains(i)
+                        ? transform::TransformDialect::kArgConsumedAttrName
+                        : transform::TransformDialect::kArgReadOnlyAttrName,
+                    UnitAttr::get(op->getContext()));
+  }
+  return success();
+}
+
+namespace {
+class InferEffectsPass
+    : public transform::impl::InferEffectsPassBase<InferEffectsPass> {
+public:
+  void runOnOperation() override {
+    WalkResult result = getOperation()->walk([](Operation *op) {
+      return failed(inferSideEffectAnnotations(op)) ? WalkResult::interrupt()
+                                                    : WalkResult::advance();
+    });
+    if (result.wasInterrupted())
+      return signalPassFailure();
+  }
+};
+} // namespace
index b3fe45e..1f651dc 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/FunctionInterfaces.h"
@@ -298,6 +299,12 @@ static void performOptionalDebugActions(
 /// Replaces external symbols in `block` with their (non-external) definitions
 /// from the given module.
 static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
+  MLIRContext &ctx = *definitions->getContext();
+  auto consumedName =
+      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
+  auto readOnlyName =
+      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
+
   for (Operation &op : llvm::make_early_inc_range(block)) {
     LLVM_DEBUG(DBGS() << op << "\n");
     auto symbol = dyn_cast<SymbolOpInterface>(op);
@@ -330,6 +337,30 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
              << externalSymbolFunc.getFunctionType() << ")";
     }
 
+    for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
+      bool isExternalConsumed =
+          externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
+      bool isExternalReadonly =
+          externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+      bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
+      bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+      if (!isExternalConsumed && !isExternalReadonly) {
+        if (isConsumed)
+          externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
+        else if (isReadonly)
+          externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+        continue;
+      }
+
+      if ((isExternalConsumed && !isConsumed) ||
+          (isExternalReadonly && !isReadonly)) {
+        return symbolFunc.emitError()
+               << "external definition has mismatching consumption annotations "
+                  "for argument #"
+               << i;
+      }
+    }
+
     OpBuilder builder(&op);
     builder.setInsertionPoint(&op);
     builder.clone(*externalSymbol);
diff --git a/mlir/test/Dialect/Transform/infer-effects.mlir b/mlir/test/Dialect/Transform/infer-effects.mlir
new file mode 100644 (file)
index 0000000..05c6a5a
--- /dev/null
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --transform-infer-effects | FileCheck %s
+
+module attributes { transform.with_named_sequence } {
+  // CHECK-LABEL: @infer
+  // CHECK-SAME: %{{.*}}: !transform.any_op {transform.consumed}
+  // CHECK-SAME: %{{.*}}: !transform.any_op {transform.readonly}
+  // CHECK-SAME: %{{.*}}: !transform.param<i32> {transform.readonly}
+  transform.named_sequence @infer(%op: !transform.any_op, %other: !transform.any_op, %param: !transform.param<i32>) {
+    transform.test_consume_operand %op : !transform.any_op
+    transform.test_print_remark_at_operand %other, "" : !transform.any_op
+    transform.yield
+  }
+}
index a6bfa64..df2792a 100644 (file)
@@ -467,3 +467,98 @@ module attributes { transform.with_named_sequence} {
     transform.yield %arg0 : !transform.any_op
   }
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-error @below {{must provide consumed/readonly status for arguments of external or called ops}}
+  transform.named_sequence @foo(%op: !transform.any_op )
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-error @below {{argument #0 cannot be both readonly and consumed}}
+  transform.named_sequence @foo(%op: !transform.any_op { transform.readonly, transform.consumed } )
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-error @below {{must provide consumed/readonly status for arguments of external or called ops}}
+  transform.named_sequence @foo(%op: !transform.any_op) {
+    transform.test_print_remark_at_operand %op, "message" : !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-error @below {{argument #0 cannot be both readonly and consumed}}
+  transform.named_sequence @foo(%op: !transform.any_op {transform.readonly, transform.consumed}) {
+    transform.test_print_remark_at_operand %op, "message" : !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-warning @below {{argument #0 is not consumed in the body but is marked as consume}}
+  transform.named_sequence @foo(%op: !transform.any_op {transform.consumed}) {
+    transform.test_print_remark_at_operand %op, "message" : !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-error @below {{argument #0 is consumed in the body but is not marked as such}}
+  transform.named_sequence @foo(%op: !transform.any_op {transform.readonly}) {
+    transform.test_consume_operand %op : !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+// -----
+
+// Checking that consumptions annotations are used correctly in invocation checks.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @foo(%op: !transform.any_op { transform.consumed } )
+
+  // expected-error @below {{'transform.sequence' op block argument #0 has more than one potential consumer}}
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-note @below {{used here as operand #0}}
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    // expected-note @below {{used here as operand #0}}
+    transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
index da50c6b..3d4cb07 100644 (file)
@@ -10,4 +10,5 @@
 // produced twice at the same location only needs to be matched once.
 
 // expected-remark @below {{message}}
+// expected-remark @below {{unannotated}}
 module {}
index bb8acf8..b21abbb 100644 (file)
@@ -6,7 +6,7 @@
 
 module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has a mismatching signature}}
-  transform.named_sequence private @foo(!transform.op<"builtin.module">)
+  transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly})
 
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.op<"builtin.module">):
@@ -25,3 +25,15 @@ module attributes {transform.with_named_sequence} {
     include @undefined_sequence failures(suppress) () : () -> ()
   }
 }
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  // expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
+  transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
+
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !transform.any_op):
+    include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
+  }
+}
index 6e23641..04b6c5a 100644 (file)
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN:             --verify-diagnostics | FileCheck %s
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
-// RUN:             --verify-diagnostics | FileCheck %s
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN:             --verify-diagnostics | FileCheck %s
+// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
 // The definition of the @foo named sequence is provided in another file. It
 // will be included because of the pass option. Repeated application of the
 // needs to be matched once.
 
 // expected-remark @below {{message}}
+// expected-remark @below {{unannotated}}
 module attributes {transform.with_named_sequence} {
   // CHECK: transform.named_sequence @foo
   // CHECK: test_print_remark_at_operand %{{.*}}, "message"
-  transform.named_sequence private @foo(!transform.any_op)
+  transform.named_sequence private @foo(!transform.any_op {transform.readonly})
+
+  // CHECK: transform.named_sequence @unannotated
+  // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated"
+  transform.named_sequence private @unannotated(!transform.any_op {transform.readonly})
 
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.any_op):
     include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+    include @unannotated failures(propagate) (%arg0) : (!transform.any_op) -> ()
   }
 }
index 509612b..1149bda 100644 (file)
@@ -1,8 +1,18 @@
 // RUN: mlir-opt %s
 
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @foo(%arg0: !transform.any_op) {
+  transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
     transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
     transform.yield
   }
+
+  transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) {
+    transform.test_consume_operand %arg0 : !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @unannotated(%arg0: !transform.any_op) {
+    transform.test_print_remark_at_operand %arg0, "unannotated" : !transform.any_op
+    transform.yield
+  }
 }
index 6b2b0dd..3c2b9b0 100644 (file)
@@ -1260,7 +1260,7 @@ transform.sequence failures(propagate) {
 
 module @named_inclusion attributes { transform.with_named_sequence } {
 
-  transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+  transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
     // expected-remark @below {{applying transformation "a"}}
     transform.test_transform_op "a"
     transform.yield
@@ -1276,13 +1276,13 @@ module @named_inclusion attributes { transform.with_named_sequence } {
 
 module @named_inclusion_in_named attributes { transform.with_named_sequence } {
 
-  transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+  transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
     // expected-remark @below {{applying transformation "a"}}
     transform.test_transform_op "a"
     transform.yield
   }
 
-  transform.named_sequence @bar(%arg0: !transform.any_op) -> () {
+  transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> () {
     // expected-remark @below {{applying transformation "b"}}
     transform.test_transform_op "b"
     transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
@@ -1300,7 +1300,8 @@ module @named_inclusion_in_named attributes { transform.with_named_sequence } {
 // expected-remark @below {{operation}}
 module @named_operands attributes { transform.with_named_sequence } {
 
-  transform.named_sequence @foo(%arg0: !transform.any_op, %arg1: !transform.any_value) -> () {
+  transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly},
+                                %arg1: !transform.any_value {transform.readonly}) -> () {
     transform.test_print_remark_at_operand %arg0, "operation" : !transform.any_op
     transform.test_print_remark_at_operand_value %arg1, "value" : !transform.any_value
     transform.yield
@@ -1322,7 +1323,7 @@ module @named_return attributes { transform.with_named_sequence } {
 
   // expected-remark @below {{value}}
   // expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}}
-  transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op, !transform.any_value) {
+  transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_value) {
     %0 = transform.test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value
     transform.yield %arg0, %0 : !transform.any_op, !transform.any_value
   }