--- /dev/null
+//===- CheckUses.cpp - Expensive transform value validity checks ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a pass that performs expensive opt-in checks for Transform
+// dialect values being potentially used after they have been consumed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SetOperations.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Returns a reference to a cached set of blocks that are reachable from the
+/// given block via edges computed by the `getNextNodes` function. For example,
+/// if `getNextNodes` returns successors of a block, this will return the set of
+/// reachable blocks; if it returns predecessors of a block, this will return
+/// the set of blocks from which the given block can be reached. The block is
+/// considered reachable form itself only if there is a cycle.
+template <typename FnTy>
+const llvm::SmallPtrSet<Block *, 4> &
+getReachableImpl(Block *block, FnTy getNextNodes,
+ DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> &cache) {
+ auto it = cache.find(block);
+ if (it != cache.end())
+ return it->getSecond();
+
+ llvm::SmallPtrSet<Block *, 4> &reachable = cache[block];
+ SmallVector<Block *> worklist;
+ worklist.push_back(block);
+ while (!worklist.empty()) {
+ Block *current = worklist.pop_back_val();
+ for (Block *predecessor : getNextNodes(current)) {
+ // The block is reachable from its transitive predecessors. Only add
+ // them to the worklist if they weren't already visited.
+ if (reachable.insert(predecessor).second)
+ worklist.push_back(predecessor);
+ }
+ }
+ return reachable;
+}
+
+/// An analysis that identifies whether a value allocated by a Transform op may
+/// be used by another such op after it may have been freed by a third op on
+/// some control flow path. This is conceptually similar to a data flow
+/// analysis, but relies on side effects related to particular values that
+/// currently cannot be modeled by the MLIR data flow analysis framework (also,
+/// the lattice element would be rather expensive as it would need to include
+/// live and/or freed values for each operation).
+///
+/// This analysis is conservatively pessimisic: it will consider that a value
+/// may be freed if it is freed on any possible control flow path between its
+/// allocation and a relevant use, even if the control never actually flows
+/// through the operation that frees the value. It also does not differentiate
+/// between may- (freed on at least one control flow path) and must-free (freed
+/// on all possible control flow paths) because it would require expensive graph
+/// algorithms.
+///
+/// It is intended as an additional non-blocking verification or debugging aid
+/// for ops in the Transform dialect. It leverages the requirement for Transform
+/// dialect ops to implement the MemoryEffectsOpInterface, and expects the
+/// values in the Transform IR to have an allocation effect on the
+/// TransformMappingResource when defined.
+class TransformOpMemFreeAnalysis {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformOpMemFreeAnalysis)
+
+ /// Computes the analysis for Transform ops nested in the given operation.
+ explicit TransformOpMemFreeAnalysis(Operation *root) {
+ root->walk([&](Operation *op) {
+ if (isa<transform::TransformOpInterface>(op)) {
+ collectFreedValues(op);
+ return WalkResult::skip();
+ }
+ return WalkResult::advance();
+ });
+ }
+
+ /// A list of operations that may be deleting a value. Non-empty list
+ /// contextually converts to boolean "true" value.
+ class PotentialDeleters {
+ public:
+ /// Creates an empty list that corresponds to the value being live.
+ static PotentialDeleters live() { return PotentialDeleters({}); }
+
+ /// Creates a list from the operations that may be deleting the value.
+ static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
+ return PotentialDeleters(deleters);
+ }
+
+ /// Converts to "true" if there are operations that may be deleting the
+ /// value.
+ explicit operator bool() const { return !deleters.empty(); }
+
+ /// Concatenates the lists of operations that may be deleting the value. The
+ /// value is known to be live if the reuslting list is still empty.
+ PotentialDeleters &operator|=(const PotentialDeleters &other) {
+ llvm::append_range(deleters, other.deleters);
+ return *this;
+ }
+
+ /// Returns the list of ops that may be deleting the value.
+ ArrayRef<Operation *> getOps() const { return deleters; }
+
+ private:
+ /// Constructs the list from the given operations.
+ explicit PotentialDeleters(ArrayRef<Operation *> ops) {
+ llvm::append_range(deleters, ops);
+ }
+
+ /// The list of operations that may be deleting the value.
+ SmallVector<Operation *> deleters;
+ };
+
+ /// Returns the list of operations that may be deleting the operand value on
+ /// any control flow path between the definition of the value and its use as
+ /// the given operand. For the purposes of this analysis, the value is
+ /// considered to be allocated at its definition point and never re-allocated.
+ PotentialDeleters isUseLive(OpOperand &operand) {
+ const llvm::SmallPtrSet<Operation *, 2> &deleters = freedBy[operand.get()];
+ if (deleters.empty())
+ return live();
+
+#ifndef NDEBUG
+ // Check that the definition point actually allcoates the value.
+ Operation *valueSource =
+ operand.get().isa<OpResult>()
+ ? operand.get().getDefiningOp()
+ : operand.get().getParentBlock()->getParentOp();
+ auto iface = cast<MemoryEffectOpInterface>(valueSource);
+ SmallVector<MemoryEffects::EffectInstance> instances;
+ iface.getEffectsOnResource(transform::TransformMappingResource::get(),
+ instances);
+ assert(hasEffect<MemoryEffects::Allocate>(instances, operand.get()) &&
+ "expected the op defining the value to have an allocation effect "
+ "on it");
+#endif
+
+ // Collect ancestors of the use operation.
+ Block *defBlock = operand.get().getParentBlock();
+ SmallVector<Operation *> ancestors;
+ Operation *ancestor = operand.getOwner();
+ do {
+ ancestors.push_back(ancestor);
+ if (ancestor->getParentRegion() == defBlock->getParent())
+ break;
+ ancestor = ancestor->getParentOp();
+ } while (true);
+ std::reverse(ancestors.begin(), ancestors.end());
+
+ // Consider the control flow from the definition point of the value to its
+ // use point. If the use is located in some nested region, consider the path
+ // from the entry block of the region to the use.
+ for (Operation *ancestor : ancestors) {
+ // The block should be considered partially if it is the block that
+ // contains the definition (allocation) of the value being used, and the
+ // value is defined in the middle of the block, i.e., is not a block
+ // argument.
+ bool isOutermost = ancestor == ancestors.front();
+ bool isFromBlockPartial = isOutermost && operand.get().isa<OpResult>();
+
+ // Check if the value may be freed by operations between its definition
+ // (allocation) point in its block and the terminator of the block or the
+ // ancestor of the use if it is located in the same block. This is only
+ // done for partial blocks here, full blocks will be considered below
+ // similarly to other blocks.
+ if (isFromBlockPartial) {
+ bool defUseSameBlock = ancestor->getBlock() == defBlock;
+ // Consider all ops from the def to its block terminator, except the
+ // when the use is in the same block, in which case only consider the
+ // ops until the user.
+ if (PotentialDeleters potentialDeleters = isFreedInBlockAfter(
+ operand.get().getDefiningOp(), operand.get(),
+ defUseSameBlock ? ancestor : nullptr))
+ return potentialDeleters;
+ }
+
+ // Check if the value may be freed by opeations preceding the ancestor in
+ // its block. Skip the check for partial blocks that contain both the
+ // definition and the use point, as this has been already checked above.
+ if (!isFromBlockPartial || ancestor->getBlock() != defBlock) {
+ if (PotentialDeleters potentialDeleters =
+ isFreedInBlockBefore(ancestor, operand.get()))
+ return potentialDeleters;
+ }
+
+ // Check if the value may be freed by operations in any of the blocks
+ // between the definition point (in the outermost region) or the entry
+ // block of the region (in other regions) and the operand or its ancestor
+ // in the region. This includes the entire "form" block if (1) the block
+ // has not been considered as partial above and (2) the block can be
+ // reached again through some control-flow loop. This includes the entire
+ // "to" block if it can be reached form itself through some control-flow
+ // cycle, regardless of whether it has been visited before.
+ Block *ancestorBlock = ancestor->getBlock();
+ Block *from =
+ isOutermost ? defBlock : &ancestorBlock->getParent()->front();
+ if (PotentialDeleters potentialDeleters =
+ isMaybeFreedOnPaths(from, ancestorBlock, operand.get(),
+ /*alwaysIncludeFrom=*/!isFromBlockPartial))
+ return potentialDeleters;
+ }
+ return live();
+ }
+
+private:
+ /// Make PotentialDeleters constructors available with shorter names.
+ static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
+ return PotentialDeleters::maybeFreed(deleters);
+ }
+ static PotentialDeleters live() { return PotentialDeleters::live(); }
+
+ /// Returns the list of operations that may be deleting the given value betwen
+ /// the first and last operations, non-inclusive. `getNext` indicates the
+ /// direction of the traversal.
+ PotentialDeleters
+ isFreedBetween(Value value, Operation *first, Operation *last,
+ llvm::function_ref<Operation *(Operation *)> getNext) const {
+ auto it = freedBy.find(value);
+ if (it == freedBy.end())
+ return live();
+ const llvm::SmallPtrSet<Operation *, 2> &deleters = it->getSecond();
+ for (Operation *op = getNext(first); op != last; op = getNext(op)) {
+ if (deleters.contains(op))
+ return maybeFreed(op);
+ }
+ return live();
+ }
+
+ /// Returns the list of operations that may be deleting the given value
+ /// between `root` and `before` values. `root` is expected to be in the same
+ /// block as `before` and precede it. If `before` is null, consider all
+ /// operations until the end of the block including the terminator.
+ PotentialDeleters isFreedInBlockAfter(Operation *root, Value value,
+ Operation *before = nullptr) const {
+ return isFreedBetween(value, root, before,
+ [](Operation *op) { return op->getNextNode(); });
+ }
+
+ /// Returns the list of operations that may be deleting the given value
+ /// between the entry of the block and the `root` operation.
+ PotentialDeleters isFreedInBlockBefore(Operation *root, Value value) const {
+ return isFreedBetween(value, root, nullptr,
+ [](Operation *op) { return op->getPrevNode(); });
+ }
+
+ /// Returns the list of operations that may be deleting the given value on
+ /// any of the control flow paths between the "form" and the "to" block. The
+ /// operations from any block visited on any control flow path are
+ /// consdiered. The "from" block is considered if there is a control flow
+ /// cycle going through it, i.e., if there is a possibility that all
+ /// operations in this block are visited or if the `alwaysIncludeFrom` flag is
+ /// set. The "to" block is considered only if there is a control flow cycle
+ /// going through it.
+ PotentialDeleters isMaybeFreedOnPaths(Block *from, Block *to, Value value,
+ bool alwaysIncludeFrom) {
+ // Find all blocks that lie on any path between "from" and "to", i.e., the
+ // intersection of blocks reachable from "from" and blocks from which "to"
+ // is rechable.
+ const llvm::SmallPtrSet<Block *, 4> &sources = getReachableFrom(to);
+ if (!sources.contains(from))
+ return live();
+
+ llvm::SmallPtrSet<Block *, 4> reachable(getReachable(from));
+ llvm::set_intersect(reachable, sources);
+
+ // If requested, include the "from" block that may not be present in the set
+ // of visited blocks when there is no cycle going through it.
+ if (alwaysIncludeFrom)
+ reachable.insert(from);
+
+ // Join potential deleters from all blocks as we don't know here which of
+ // the paths through the control flow is taken.
+ PotentialDeleters potentialDeleters = live();
+ for (Block *block : reachable) {
+ for (Operation &op : *block) {
+ if (freedBy[value].count(&op))
+ potentialDeleters |= maybeFreed(&op);
+ }
+ }
+ return potentialDeleters;
+ }
+
+ /// Popualtes `reachable` with the set of blocks that are rechable from the
+ /// given block. A block is considered reachable from itself if there is a
+ /// cycle in the control-flow graph that invovles the block.
+ const llvm::SmallPtrSet<Block *, 4> &getReachable(Block *block) {
+ return getReachableImpl(
+ block, [](Block *b) { return b->getSuccessors(); }, reachableCache);
+ }
+
+ /// Populates `sources` with the set of blocks from which the given block is
+ /// reachable.
+ const llvm::SmallPtrSet<Block *, 4> &getReachableFrom(Block *block) {
+ return getReachableImpl(
+ block, [](Block *b) { return b->getPredecessors(); },
+ reachableFromCache);
+ }
+
+ /// Returns true of `instances` contains an effect of `EffectTy` on `value`.
+ template <typename EffectTy>
+ static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances,
+ Value value) {
+ return llvm::any_of(instances,
+ [&](const MemoryEffects::EffectInstance &instance) {
+ return instance.getValue() == value &&
+ isa<EffectTy>(instance.getEffect());
+ });
+ }
+
+ /// Records the values that are being freed by an operation or any of its
+ /// children in `freedBy`.
+ void collectFreedValues(Operation *root) {
+ SmallVector<MemoryEffects::EffectInstance> instances;
+ root->walk([&](Operation *child) {
+ // TODO: extend this to conservatively handle operations with undeclared
+ // side effects as maybe freeing the operands.
+ auto iface = cast<MemoryEffectOpInterface>(child);
+ instances.clear();
+ iface.getEffectsOnResource(transform::TransformMappingResource::get(),
+ instances);
+ for (Value operand : child->getOperands()) {
+ if (hasEffect<MemoryEffects::Free>(instances, operand)) {
+ // All parents of the operation that frees a value should be
+ // considered as potentially freeing the value as well.
+ //
+ // TODO: differentiate between must-free/may-free as well as between
+ // this op having the effect and children having the effect. This may
+ // require some analysis of all control flow paths through the nested
+ // regions as well as a mechanism to separate proper side effects from
+ // those obtained by nesting.
+ Operation *parent = child;
+ do {
+ freedBy[operand].insert(parent);
+ if (parent == root)
+ break;
+ parent = parent->getParentOp();
+ } while (true);
+ }
+ }
+ });
+ }
+
+ /// The mapping from a value to operations that have a Free memory effect on
+ /// the TransformMappingResource and associated with this value, or to
+ /// Transform operations transitively containing such operations.
+ DenseMap<Value, llvm::SmallPtrSet<Operation *, 2>> freedBy;
+
+ /// Caches for sets of reachable blocks.
+ DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableCache;
+ DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableFromCache;
+};
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+
+//// A simple pass that warns about any use of a value by a transform operation
+// that may be using the value after it has been freed.
+class CheckUsesPass : public CheckUsesBase<CheckUsesPass> {
+public:
+ void runOnOperation() override {
+ auto &analysis = getAnalysis<TransformOpMemFreeAnalysis>();
+
+ getOperation()->walk([&](Operation *child) {
+ for (OpOperand &operand : child->getOpOperands()) {
+ TransformOpMemFreeAnalysis::PotentialDeleters deleters =
+ analysis.isUseLive(operand);
+ if (!deleters)
+ continue;
+
+ InFlightDiagnostic diag = child->emitWarning()
+ << "operand #" << operand.getOperandNumber()
+ << " may be used after free";
+ diag.attachNote(operand.get().getLoc()) << "allocated here";
+ for (Operation *d : deleters.getOps()) {
+ diag.attachNote(d->getLoc()) << "freed here";
+ }
+ }
+ });
+ }
+};
+
+} // namespace
+
+namespace mlir {
+namespace transform {
+std::unique_ptr<Pass> createCheckUsesPass() {
+ return std::make_unique<CheckUsesPass>();
+}
+} // namespace transform
+} // namespace mlir
--- /dev/null
+// RUN: mlir-opt %s --transform-dialect-check-uses --split-input-file --verify-diagnostics
+
+func.func @use_after_free_branching_control_flow() {
+ // expected-note @below {{allocated here}}
+ %0 = transform.test_produce_param_or_forward_operand 42
+ transform.test_transform_op_with_regions {
+ "transform.test_branching_transform_op_terminator"() : () -> ()
+ },
+ {
+ ^bb0:
+ "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
+ ^bb1:
+ // expected-note @below {{freed here}}
+ transform.test_consume_operand_if_matches_param_or_fail %0[42]
+ "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+ ^bb2:
+ "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+ ^bb3:
+ // expected-warning @below {{operand #0 may be used after free}}
+ transform.sequence %0 {
+ ^bb0(%arg0: !pdl.operation):
+ }
+ "transform.test_branching_transform_op_terminator"() : () -> ()
+ }
+ return
+}
+
+// -----
+
+func.func @use_after_free_in_nested_op() {
+ // expected-note @below {{allocated here}}
+ %0 = transform.test_produce_param_or_forward_operand 42
+ // expected-note @below {{freed here}}
+ transform.test_transform_op_with_regions {
+ "transform.test_branching_transform_op_terminator"() : () -> ()
+ },
+ {
+ ^bb0:
+ "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
+ ^bb1:
+ transform.test_consume_operand_if_matches_param_or_fail %0[42]
+ "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+ ^bb2:
+ "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+ ^bb3:
+ "transform.test_branching_transform_op_terminator"() : () -> ()
+ }
+ // expected-warning @below {{operand #0 may be used after free}}
+ transform.sequence %0 {
+ ^bb0(%arg0: !pdl.operation):
+ }
+ return
+}
+
+// -----
+
+func.func @use_after_free_recursive_side_effects() {
+ transform.sequence {
+ ^bb0(%arg0: !pdl.operation):
+ // expected-note @below {{allocated here}}
+ %0 = transform.sequence %arg0 attributes { ord = 1 } {
+ ^bb1(%arg1: !pdl.operation):
+ yield %arg1 : !pdl.operation
+ } : !pdl.operation
+ transform.sequence %0 attributes { ord = 2 } {
+ ^bb2(%arg2: !pdl.operation):
+ }
+ transform.sequence %0 attributes { ord = 3 } {
+ ^bb3(%arg3: !pdl.operation):
+ }
+
+ // `transform.sequence` has recursive side effects so it has the same "free"
+ // as the child op it contains.
+ // expected-note @below {{freed here}}
+ transform.sequence %0 attributes { ord = 4 } {
+ ^bb4(%arg4: !pdl.operation):
+ test_consume_operand_if_matches_param_or_fail %0[42]
+ }
+ // expected-warning @below {{operand #0 may be used after free}}
+ transform.sequence %0 attributes { ord = 5 } {
+ ^bb3(%arg3: !pdl.operation):
+ }
+ }
+ return
+}
+
+// -----
+
+func.func @use_after_free() {
+ transform.sequence {
+ ^bb0(%arg0: !pdl.operation):
+ // expected-note @below {{allocated here}}
+ %0 = transform.sequence %arg0 attributes { ord = 1 } {
+ ^bb1(%arg1: !pdl.operation):
+ yield %arg1 : !pdl.operation
+ } : !pdl.operation
+ transform.sequence %0 attributes { ord = 2 } {
+ ^bb2(%arg2: !pdl.operation):
+ }
+ transform.sequence %0 attributes { ord = 3 } {
+ ^bb3(%arg3: !pdl.operation):
+ }
+
+ // expected-note @below {{freed here}}
+ test_consume_operand_if_matches_param_or_fail %0[42]
+ // expected-warning @below {{operand #0 may be used after free}}
+ transform.sequence %0 attributes { ord = 5 } {
+ ^bb3(%arg3: !pdl.operation):
+ }
+ }
+ return
+}
+
+// -----
+
+// In the case of a control flow cycle, the operation that uses the value may
+// precede the one that frees it in the same block. Both operations should
+// be reported as use-after-free.
+func.func @use_after_free_self_cycle() {
+ // expected-note @below {{allocated here}}
+ %0 = transform.test_produce_param_or_forward_operand 42
+ transform.test_transform_op_with_regions {
+ "transform.test_branching_transform_op_terminator"() : () -> ()
+ },
+ {
+ ^bb0:
+ "transform.test_branching_transform_op_terminator"()[^bb1] : () -> ()
+ ^bb1:
+ // expected-warning @below {{operand #0 may be used after free}}
+ transform.sequence %0 {
+ ^bb0(%arg0: !pdl.operation):
+ }
+ // expected-warning @below {{operand #0 may be used after free}}
+ // expected-note @below {{freed here}}
+ transform.test_consume_operand_if_matches_param_or_fail %0[42]
+ "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
+ ^bb2:
+ "transform.test_branching_transform_op_terminator"() : () -> ()
+ }
+ return
+}
+
+
+// -----
+
+// Check that the "free" that happens in a cycle is also reported as potential
+// use-after-free.
+func.func @use_after_free_cycle() {
+ // expected-note @below {{allocated here}}
+ %0 = transform.test_produce_param_or_forward_operand 42
+ transform.test_transform_op_with_regions {
+ "transform.test_branching_transform_op_terminator"() : () -> ()
+ },
+ {
+ ^bb0:
+ "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
+ ^bb1:
+ // expected-warning @below {{operand #0 may be used after free}}
+ // expected-note @below {{freed here}}
+ transform.test_consume_operand_if_matches_param_or_fail %0[42]
+ "transform.test_branching_transform_op_terminator"()[^bb2, ^bb3] : () -> ()
+ ^bb2:
+ "transform.test_branching_transform_op_terminator"()[^bb1] : () -> ()
+ ^bb3:
+ "transform.test_branching_transform_op_terminator"() : () -> ()
+ }
+ return
+}
+