constexpr const static llvm::StringLiteral
kTargetTagAttrName = "transform.target_tag";
+ /// Names of the attributes indicating whether an argument of an external
+ /// transform dialect symbol is consumed or only read.
+ constexpr const static llvm::StringLiteral
+ kArgConsumedAttrName = "transform.consumed";
+ constexpr const static llvm::StringLiteral
+ kArgReadOnlyAttrName = "transform.readonly";
+
/// Returns the named PDL constraint functions available in the dialect
/// as a map from their name to the function.
const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
}];
}
-// Base class for ops that belong to the tranfsorm dialect. Ops defined in
+// Base class for ops that belong to the transform dialect. Ops defined in
// extensions of this dialect may also use this.
class TransformDialectOp<string mnemonic, list<Trait> traits = []>
: Op<Transform_Dialect, mnemonic, traits>;
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+/// Populates `consumedArguments` with positions of `block` arguments that are
+/// consumed by the operations in the `block`.
+void getConsumedBlockArguments(
+ Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
+
/// Trait implementing the MemoryEffectOpInterface for operations that "consume"
/// their operands and produce new results.
template <typename OpTy>
-//===- CheckUses.h - Expensive transform value validity checks --*- C++ -*-===//
+//===- Passes.h - Transform dialect pass entry points -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
}];
}
+def InferEffectsPass : Pass<"transform-infer-effects"> {
+ let summary = "infer transform side effects for symbols";
+ let description = [{
+ This pass analyzes the definitions of transform dialect callable symbol
+ operations, such as `transform.named_sequence`, and annotates the symbol
+ arguments with attributes indicating the side effects that the nested
+ operations have on them.
+ }];
+}
+
#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
}
return success();
}
+ if (attribute.getName().getValue() == kArgConsumedAttrName ||
+ attribute.getName().getValue() == kArgReadOnlyAttrName) {
+ if (!attribute.getValue().isa<UnitAttr>()) {
+ return op->emitError()
+ << attribute.getName() << " must be a unit attribute";
+ }
+ return success();
+ }
return emitError(op->getLoc())
<< "unknown attribute: " << attribute.getName();
}
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
}
+void transform::getConsumedBlockArguments(
+ Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ for (Operation &nested : block) {
+ auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
+ if (!iface)
+ continue;
+
+ effects.clear();
+ iface.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ BlockArgument argument =
+ dyn_cast_or_null<BlockArgument>(effect.getValue());
+ if (!argument || argument.getOwner() != &block ||
+ !isa<MemoryEffects::Free>(effect.getEffect()) ||
+ effect.getResource() != transform::TransformMappingResource::get()) {
+ continue;
+ }
+ consumedArguments.insert(argument.getArgNumber());
+ }
+ }
+}
+
//===----------------------------------------------------------------------===//
// Utilities for TransformOpInterface.
//===----------------------------------------------------------------------===//
void transform::IncludeOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ // Always mark as modifying the payload.
+ // TODO: a mechanism to annotate effects on payload. Even when all handles are
+ // only read, the payload may still be modified, so we currently stay on the
+ // conservative side and always indicate modification. This may prevent some
+ // code reordering.
+ modifiesPayload(effects);
+
+ // Results are always produced.
+ producesHandle(getResults(), effects);
+
+ // Adds default effects to operands and results. This will be added if
+ // preconditions fail so the trait verifier doesn't complain about missing
+ // effects and the real precondition failure is reported later on.
+ auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); };
+
// Bail if the callee is unknown. This may run as part of the verification
// process before we verified the validity of the callee or of this op.
auto target =
getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
if (!target)
- return;
+ return defaultEffects();
auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
getOperation(), getTarget());
if (!callee)
- return;
+ return defaultEffects();
DiagnosedSilenceableFailure earlyVerifierResult =
verifyNamedSequenceOp(callee);
if (!earlyVerifierResult.succeeded()) {
(void)earlyVerifierResult.silence();
- return;
+ return defaultEffects();
}
- // Carry over effects from the callee.
- // TODO: external callees must provides attributes annotating the
- // readonly/consume effects on operands.
- if (!callee.isExternal())
- remapArgumentEffects(callee.getBody().front(), getOperands(), effects);
-
- // Proper effects.
- onlyReadsHandle(getOperands(), effects);
- producesHandle(getResults(), effects);
+ for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
+ if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
+ consumesHandle(getOperand(i), effects);
+ else
+ onlyReadsHandle(getOperand(i), effects);
+ }
}
template <typename... Tys>
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
}
+/// Checks that the attributes of the named sequence operation have correct
+/// consumption effect annotations. If `alsoVerifyInternal`, checks for
+/// annotations being present even if they can be inferred from the body.
+static DiagnosedSilenceableFailure
+verifyNamedSequenceConsumeAnnotations(transform::NamedSequenceOp op,
+ bool alsoVerifyInternal = false) {
+ llvm::SmallDenseSet<unsigned> consumedArguments;
+ if (!op.isExternal()) {
+ transform::getConsumedBlockArguments(op.getBody().front(),
+ consumedArguments);
+ }
+ for (unsigned i = 0, e = op.getFunctionType().getNumInputs(); i < e; ++i) {
+ bool isConsumed =
+ op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
+ nullptr;
+ bool isReadOnly =
+ op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
+ nullptr;
+ if (isConsumed && isReadOnly) {
+ return op.emitSilenceableError()
+ << "argument #" << i << " cannot be both readonly and consumed";
+ }
+ if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
+ return op.emitSilenceableError()
+ << "must provide consumed/readonly status for arguments of "
+ "external or called ops";
+ }
+ if (op.isExternal())
+ continue;
+
+ if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
+ return op.emitSilenceableError()
+ << "argument #" << i
+ << " is consumed in the body but is not marked as such";
+ }
+ if (!consumedArguments.contains(i) && isConsumed) {
+ Diagnostic warning(op->getLoc(), DiagnosticSeverity::Warning);
+ warning << "argument #" << i
+ << " is not consumed in the body but is marked as consumed";
+ return DiagnosedSilenceableFailure::silenceableFailure(
+ std::move(warning));
+ }
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
LogicalResult
transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Access through indirection and do additional checking because this may be
}
}
- return success();
+ return verifyNamedSequenceConsumeAnnotations(target,
+ /*alsoVerifyInternal=*/true)
+ .checkAndReport();
}
//===----------------------------------------------------------------------===//
}
if (op.isExternal() || op.getBody().empty())
- return DiagnosedSilenceableFailure::success();
+ return verifyNamedSequenceConsumeAnnotations(op);
if (op.getBody().front().empty())
return emitSilenceableFailure(op) << "expected a non-empty body block";
<< operandType << " vs " << resultType << ")";
}
- return DiagnosedSilenceableFailure::success();
+ return verifyNamedSequenceConsumeAnnotations(op);
}
LogicalResult transform::NamedSequenceOp::verify() {
add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
+ InferEffects.cpp
TransformInterpreterPassBase.cpp
DEPENDS
--- /dev/null
+//===- InferEffects.cpp - Infer memory effects for named symbols ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/DenseSet.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace transform {
+#define GEN_PASS_DEF_INFEREFFECTSPASS
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+} // namespace transform
+} // namespace mlir
+
+static LogicalResult inferSideEffectAnnotations(Operation *op) {
+ if (!isa<transform::TransformOpInterface>(op))
+ return success();
+
+ auto func = dyn_cast<FunctionOpInterface>(op);
+ if (!func || func.isExternal())
+ return success();
+
+ if (!func.getFunctionBody().hasOneBlock()) {
+ return op->emitError()
+ << "only single-block operations are currently supported";
+ }
+
+ // Note that there can't be an inclusion of an unannotated symbol because it
+ // wouldn't have passed the verifier, so recursion isn't necessary here.
+ llvm::SmallDenseSet<unsigned> consumedArguments;
+ transform::getConsumedBlockArguments(func.getFunctionBody().front(),
+ consumedArguments);
+
+ for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
+ func.setArgAttr(i,
+ consumedArguments.contains(i)
+ ? transform::TransformDialect::kArgConsumedAttrName
+ : transform::TransformDialect::kArgReadOnlyAttrName,
+ UnitAttr::get(op->getContext()));
+ }
+ return success();
+}
+
+namespace {
+class InferEffectsPass
+ : public transform::impl::InferEffectsPassBase<InferEffectsPass> {
+public:
+ void runOnOperation() override {
+ WalkResult result = getOperation()->walk([](Operation *op) {
+ return failed(inferSideEffectAnnotations(op)) ? WalkResult::interrupt()
+ : WalkResult::advance();
+ });
+ if (result.wasInterrupted())
+ return signalPassFailure();
+ }
+};
+} // namespace
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/FunctionInterfaces.h"
/// Replaces external symbols in `block` with their (non-external) definitions
/// from the given module.
static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
+ MLIRContext &ctx = *definitions->getContext();
+ auto consumedName =
+ StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
+ auto readOnlyName =
+ StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
+
for (Operation &op : llvm::make_early_inc_range(block)) {
LLVM_DEBUG(DBGS() << op << "\n");
auto symbol = dyn_cast<SymbolOpInterface>(op);
<< externalSymbolFunc.getFunctionType() << ")";
}
+ for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
+ bool isExternalConsumed =
+ externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
+ bool isExternalReadonly =
+ externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+ bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
+ bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+ if (!isExternalConsumed && !isExternalReadonly) {
+ if (isConsumed)
+ externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
+ else if (isReadonly)
+ externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+ continue;
+ }
+
+ if ((isExternalConsumed && !isConsumed) ||
+ (isExternalReadonly && !isReadonly)) {
+ return symbolFunc.emitError()
+ << "external definition has mismatching consumption annotations "
+ "for argument #"
+ << i;
+ }
+ }
+
OpBuilder builder(&op);
builder.setInsertionPoint(&op);
builder.clone(*externalSymbol);
--- /dev/null
+// RUN: mlir-opt %s --transform-infer-effects | FileCheck %s
+
+module attributes { transform.with_named_sequence } {
+ // CHECK-LABEL: @infer
+ // CHECK-SAME: %{{.*}}: !transform.any_op {transform.consumed}
+ // CHECK-SAME: %{{.*}}: !transform.any_op {transform.readonly}
+ // CHECK-SAME: %{{.*}}: !transform.param<i32> {transform.readonly}
+ transform.named_sequence @infer(%op: !transform.any_op, %other: !transform.any_op, %param: !transform.param<i32>) {
+ transform.test_consume_operand %op : !transform.any_op
+ transform.test_print_remark_at_operand %other, "" : !transform.any_op
+ transform.yield
+ }
+}
transform.yield %arg0 : !transform.any_op
}
}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{must provide consumed/readonly status for arguments of external or called ops}}
+ transform.named_sequence @foo(%op: !transform.any_op )
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{argument #0 cannot be both readonly and consumed}}
+ transform.named_sequence @foo(%op: !transform.any_op { transform.readonly, transform.consumed } )
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{must provide consumed/readonly status for arguments of external or called ops}}
+ transform.named_sequence @foo(%op: !transform.any_op) {
+ transform.test_print_remark_at_operand %op, "message" : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{argument #0 cannot be both readonly and consumed}}
+ transform.named_sequence @foo(%op: !transform.any_op {transform.readonly, transform.consumed}) {
+ transform.test_print_remark_at_operand %op, "message" : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-warning @below {{argument #0 is not consumed in the body but is marked as consume}}
+ transform.named_sequence @foo(%op: !transform.any_op {transform.consumed}) {
+ transform.test_print_remark_at_operand %op, "message" : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{argument #0 is consumed in the body but is not marked as such}}
+ transform.named_sequence @foo(%op: !transform.any_op {transform.readonly}) {
+ transform.test_consume_operand %op : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+// Checking that consumptions annotations are used correctly in invocation checks.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @foo(%op: !transform.any_op { transform.consumed } )
+
+ // expected-error @below {{'transform.sequence' op block argument #0 has more than one potential consumer}}
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-note @below {{used here as operand #0}}
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ // expected-note @below {{used here as operand #0}}
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
// produced twice at the same location only needs to be matched once.
// expected-remark @below {{message}}
+// expected-remark @below {{unannotated}}
module {}
module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has a mismatching signature}}
- transform.named_sequence private @foo(!transform.op<"builtin.module">)
+ transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly})
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"builtin.module">):
include @undefined_sequence failures(suppress) () : () -> ()
}
}
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ // expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
+ transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
+
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
+ }
+}
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN: --verify-diagnostics | FileCheck %s
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
-// RUN: --verify-diagnostics | FileCheck %s
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN: --verify-diagnostics | FileCheck %s
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// The definition of the @foo named sequence is provided in another file. It
// will be included because of the pass option. Repeated application of the
// needs to be matched once.
// expected-remark @below {{message}}
+// expected-remark @below {{unannotated}}
module attributes {transform.with_named_sequence} {
// CHECK: transform.named_sequence @foo
// CHECK: test_print_remark_at_operand %{{.*}}, "message"
- transform.named_sequence private @foo(!transform.any_op)
+ transform.named_sequence private @foo(!transform.any_op {transform.readonly})
+
+ // CHECK: transform.named_sequence @unannotated
+ // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated"
+ transform.named_sequence private @unannotated(!transform.any_op {transform.readonly})
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ include @unannotated failures(propagate) (%arg0) : (!transform.any_op) -> ()
}
}
// RUN: mlir-opt %s
module attributes {transform.with_named_sequence} {
- transform.named_sequence @foo(%arg0: !transform.any_op) {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
transform.yield
}
+
+ transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) {
+ transform.test_consume_operand %arg0 : !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @unannotated(%arg0: !transform.any_op) {
+ transform.test_print_remark_at_operand %arg0, "unannotated" : !transform.any_op
+ transform.yield
+ }
}
module @named_inclusion attributes { transform.with_named_sequence } {
- transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
// expected-remark @below {{applying transformation "a"}}
transform.test_transform_op "a"
transform.yield
module @named_inclusion_in_named attributes { transform.with_named_sequence } {
- transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
// expected-remark @below {{applying transformation "a"}}
transform.test_transform_op "a"
transform.yield
}
- transform.named_sequence @bar(%arg0: !transform.any_op) -> () {
+ transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> () {
// expected-remark @below {{applying transformation "b"}}
transform.test_transform_op "b"
transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
// expected-remark @below {{operation}}
module @named_operands attributes { transform.with_named_sequence } {
- transform.named_sequence @foo(%arg0: !transform.any_op, %arg1: !transform.any_value) -> () {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly},
+ %arg1: !transform.any_value {transform.readonly}) -> () {
transform.test_print_remark_at_operand %arg0, "operation" : !transform.any_op
transform.test_print_remark_at_operand_value %arg1, "value" : !transform.any_value
transform.yield
// expected-remark @below {{value}}
// expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}}
- transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op, !transform.any_value) {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_value) {
%0 = transform.test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value
transform.yield %arg0, %0 : !transform.any_op, !transform.any_value
}