[mlir] Use-after-free checker for the Transform dialect
authorAlex Zinenko <zinenko@google.com>
Wed, 25 May 2022 13:57:22 +0000 (15:57 +0200)
committerAlex Zinenko <zinenko@google.com>
Thu, 26 May 2022 10:28:41 +0000 (12:28 +0200)
The Transform dialect uses the side effect modeling mechanism to record the
effects of the transform ops on the mapping between Transform IR values and
Payload IR ops. Introduce a checker pass that warns if a Transform IR value is
used after it has been freed (consumed). This pass is mostly intended as a
debugging aid in addition to the verification/assertion mechanisms in the
transform interpreter. It reports all potential use-after-free situations.
The implementation makes a series of simplifying assumptions to be simple and
conservative. A more advanced implementation would rely on the data flow-like
analysis associated with a side-effect resource rather than a value, which is
currently not supported by the analysis infrastructure.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D126381

15 files changed:
mlir/include/mlir/Dialect/Transform/CMakeLists.txt
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/Dialect/Transform/Transforms/Passes.h [new file with mode: 0644]
mlir/include/mlir/Dialect/Transform/Transforms/Passes.td [new file with mode: 0644]
mlir/include/mlir/InitAllPasses.h
mlir/lib/Dialect/Transform/CMakeLists.txt
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp [new file with mode: 0644]
mlir/test/Dialect/Transform/check-use-after-free.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 9714b77..70ef677 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.h.inc"
index 2548492..0f1db45 100644 (file)
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
 
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
@@ -48,7 +49,10 @@ def PDLMatchOp : TransformDialectOp<"pdl_match",
 }
 
 def SequenceOp : TransformDialectOp<"sequence",
-    [DeclareOpInterfaceMethods<TransformOpInterface>,
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+        ["getSuccessorEntryOperands", "getSuccessorRegions",
+         "getRegionInvocationBounds"]>,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
      SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt
new file mode 100644 (file)
index 0000000..3a399e6
--- /dev/null
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Transform)
+add_public_tablegen_target(MLIRTransformDialectTransformsIncGen)
+
+add_mlir_doc(Passes TransformPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.h b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.h
new file mode 100644 (file)
index 0000000..a567db6
--- /dev/null
@@ -0,0 +1,26 @@
+//===- CheckUses.h - Expensive transform value validity checks --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+namespace transform {
+std::unique_ptr<Pass> createCheckUsesPass();
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
new file mode 100644 (file)
index 0000000..1d7d865
--- /dev/null
@@ -0,0 +1,36 @@
+//===-- Passes.td - Transform dialect pass definitions -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
+#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def CheckUses : Pass<"transform-dialect-check-uses"> {
+  let summary = "warn about potential use-after-free in the transform dialect";
+  let description = [{
+    This pass analyzes operations from the transform dialect and its extensions
+    and warns if a transform IR value may be used by an operation after it was
+    "freed" by some other operation, as described by side effects on the
+    `TransformMappingResource`. This statically detects situations that lead to
+    errors when interpreting the Transform IR.
+
+    The pass is capable of handling branching control flow and reports all
+    _potential_ use-after-free situations, e.g., a may-use-after-free is
+    reported if at least one of the control flow paths between the definition of
+    a value and its use contains an operation with a "free" effect on the
+    `TransformMappingResource`. It does not currently perform an SCCP-style data
+    flow analysis to prove that some branches are not taken, however, SCCP and
+    other control flow simplifications can be performed on the transform IR
+    prior to this pass provided that transform ops implement the relevant
+    control flow interfaces.
+  }];
+  let constructor = "::mlir::transform::createCheckUsesPass()";
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
index ccbb74d..d6aacb4 100644 (file)
@@ -32,6 +32,7 @@
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/Transforms/Passes.h"
 
@@ -72,6 +73,7 @@ inline void registerAllPasses() {
   spirv::registerSPIRVPasses();
   tensor::registerTensorPasses();
   tosa::registerTosaOptPasses();
+  transform::registerTransformPasses();
   vector::registerVectorPasses();
 
   // Dialect pipelines
index 3dcf414..fa3ba7e 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -289,6 +290,35 @@ void transform::SequenceOp::getEffects(
   }
 }
 
+OperandRange transform::SequenceOp::getSuccessorEntryOperands(unsigned index) {
+  assert(index == 0 && "unexpected region index");
+  if (getOperation()->getNumOperands() == 1)
+    return getOperation()->getOperands();
+  return OperandRange(getOperation()->operand_end(),
+                      getOperation()->operand_end());
+}
+
+void transform::SequenceOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!index.hasValue()) {
+    Region *bodyRegion = &getBody();
+    regions.emplace_back(bodyRegion, !operands.empty()
+                                         ? bodyRegion->getArguments()
+                                         : Block::BlockArgListType());
+    return;
+  }
+
+  assert(*index == 0 && "unexpected region index");
+  regions.emplace_back(getOperation()->getResults());
+}
+
+void transform::SequenceOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+  (void)operands;
+  bounds.emplace_back(1, 1);
+}
+
 //===----------------------------------------------------------------------===//
 // WithPDLPatternsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
new file mode 100644 (file)
index 0000000..68ad95b
--- /dev/null
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRTransformDialectTransforms
+  CheckUses.cpp
+
+  DEPENDS
+  MLIRTransformDialectTransformsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRTransformDialect
+  MLIRIR
+  MLIRPass
+)
diff --git a/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp
new file mode 100644 (file)
index 0000000..a325a62
--- /dev/null
@@ -0,0 +1,402 @@
+//===- CheckUses.cpp - Expensive transform value validity checks ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a pass that performs expensive opt-in checks for Transform
+// dialect values being potentially used after they have been consumed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SetOperations.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Returns a reference to a cached set of blocks that are reachable from the
+/// given block via edges computed by the `getNextNodes` function. For example,
+/// if `getNextNodes` returns successors of a block, this will return the set of
+/// reachable blocks; if it returns predecessors of a block, this will return
+/// the set of blocks from which the given block can be reached. The block is
+/// considered reachable form itself only if there is a cycle.
+template <typename FnTy>
+const llvm::SmallPtrSet<Block *, 4> &
+getReachableImpl(Block *block, FnTy getNextNodes,
+                 DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> &cache) {
+  auto it = cache.find(block);
+  if (it != cache.end())
+    return it->getSecond();
+
+  llvm::SmallPtrSet<Block *, 4> &reachable = cache[block];
+  SmallVector<Block *> worklist;
+  worklist.push_back(block);
+  while (!worklist.empty()) {
+    Block *current = worklist.pop_back_val();
+    for (Block *predecessor : getNextNodes(current)) {
+      // The block is reachable from its transitive predecessors. Only add
+      // them to the worklist if they weren't already visited.
+      if (reachable.insert(predecessor).second)
+        worklist.push_back(predecessor);
+    }
+  }
+  return reachable;
+}
+
+/// An analysis that identifies whether a value allocated by a Transform op may
+/// be used by another such op after it may have been freed by a third op on
+/// some control flow path. This is conceptually similar to a data flow
+/// analysis, but relies on side effects related to particular values that
+/// currently cannot be modeled by the MLIR data flow analysis framework (also,
+/// the lattice element would be rather expensive as it would need to include
+/// live and/or freed values for each operation).
+///
+/// This analysis is conservatively pessimisic: it will consider that a value
+/// may be freed if it is freed on any possible control flow path between its
+/// allocation and a relevant use, even if the control never actually flows
+/// through the operation that frees the value. It also does not differentiate
+/// between may- (freed on at least one control flow path) and must-free (freed
+/// on all possible control flow paths) because it would require expensive graph
+/// algorithms.
+///
+/// It is intended as an additional non-blocking verification or debugging aid
+/// for ops in the Transform dialect. It leverages the requirement for Transform
+/// dialect ops to implement the MemoryEffectsOpInterface, and expects the
+/// values in the Transform IR to have an allocation effect on the
+/// TransformMappingResource when defined.
+class TransformOpMemFreeAnalysis {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformOpMemFreeAnalysis)
+
+  /// Computes the analysis for Transform ops nested in the given operation.
+  explicit TransformOpMemFreeAnalysis(Operation *root) {
+    root->walk([&](Operation *op) {
+      if (isa<transform::TransformOpInterface>(op)) {
+        collectFreedValues(op);
+        return WalkResult::skip();
+      }
+      return WalkResult::advance();
+    });
+  }
+
+  /// A list of operations that may be deleting a value. Non-empty list
+  /// contextually converts to boolean "true" value.
+  class PotentialDeleters {
+  public:
+    /// Creates an empty list that corresponds to the value being live.
+    static PotentialDeleters live() { return PotentialDeleters({}); }
+
+    /// Creates a list from the operations that may be deleting the value.
+    static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
+      return PotentialDeleters(deleters);
+    }
+
+    /// Converts to "true" if there are operations that may be deleting the
+    /// value.
+    explicit operator bool() const { return !deleters.empty(); }
+
+    /// Concatenates the lists of operations that may be deleting the value. The
+    /// value is known to be live if the reuslting list is still empty.
+    PotentialDeleters &operator|=(const PotentialDeleters &other) {
+      llvm::append_range(deleters, other.deleters);
+      return *this;
+    }
+
+    /// Returns the list of ops that may be deleting the value.
+    ArrayRef<Operation *> getOps() const { return deleters; }
+
+  private:
+    /// Constructs the list from the given operations.
+    explicit PotentialDeleters(ArrayRef<Operation *> ops) {
+      llvm::append_range(deleters, ops);
+    }
+
+    /// The list of operations that may be deleting the value.
+    SmallVector<Operation *> deleters;
+  };
+
+  /// Returns the list of operations that may be deleting the operand value on
+  /// any control flow path between the definition of the value and its use as
+  /// the given operand. For the purposes of this analysis, the value is
+  /// considered to be allocated at its definition point and never re-allocated.
+  PotentialDeleters isUseLive(OpOperand &operand) {
+    const llvm::SmallPtrSet<Operation *, 2> &deleters = freedBy[operand.get()];
+    if (deleters.empty())
+      return live();
+
+#ifndef NDEBUG
+    // Check that the definition point actually allcoates the value.
+    Operation *valueSource =
+        operand.get().isa<OpResult>()
+            ? operand.get().getDefiningOp()
+            : operand.get().getParentBlock()->getParentOp();
+    auto iface = cast<MemoryEffectOpInterface>(valueSource);
+    SmallVector<MemoryEffects::EffectInstance> instances;
+    iface.getEffectsOnResource(transform::TransformMappingResource::get(),
+                               instances);
+    assert(hasEffect<MemoryEffects::Allocate>(instances, operand.get()) &&
+           "expected the op defining the value to have an allocation effect "
+           "on it");
+#endif
+
+    // Collect ancestors of the use operation.
+    Block *defBlock = operand.get().getParentBlock();
+    SmallVector<Operation *> ancestors;
+    Operation *ancestor = operand.getOwner();
+    do {
+      ancestors.push_back(ancestor);
+      if (ancestor->getParentRegion() == defBlock->getParent())
+        break;
+      ancestor = ancestor->getParentOp();
+    } while (true);
+    std::reverse(ancestors.begin(), ancestors.end());
+
+    // Consider the control flow from the definition point of the value to its
+    // use point. If the use is located in some nested region, consider the path
+    // from the entry block of the region to the use.
+    for (Operation *ancestor : ancestors) {
+      // The block should be considered partially if it is the block that
+      // contains the definition (allocation) of the value being used, and the
+      // value is defined in the middle of the block, i.e., is not a block
+      // argument.
+      bool isOutermost = ancestor == ancestors.front();
+      bool isFromBlockPartial = isOutermost && operand.get().isa<OpResult>();
+
+      // Check if the value may be freed by operations between its definition
+      // (allocation) point in its block and the terminator of the block or the
+      // ancestor of the use if it is located in the same block. This is only
+      // done for partial blocks here, full blocks will be considered below
+      // similarly to other blocks.
+      if (isFromBlockPartial) {
+        bool defUseSameBlock = ancestor->getBlock() == defBlock;
+        // Consider all ops from the def to its block terminator, except the
+        // when the use is in the same block, in which case only consider the
+        // ops until the user.
+        if (PotentialDeleters potentialDeleters = isFreedInBlockAfter(
+                operand.get().getDefiningOp(), operand.get(),
+                defUseSameBlock ? ancestor : nullptr))
+          return potentialDeleters;
+      }
+
+      // Check if the value may be freed by opeations preceding the ancestor in
+      // its block. Skip the check for partial blocks that contain both the
+      // definition and the use point, as this has been already checked above.
+      if (!isFromBlockPartial || ancestor->getBlock() != defBlock) {
+        if (PotentialDeleters potentialDeleters =
+                isFreedInBlockBefore(ancestor, operand.get()))
+          return potentialDeleters;
+      }
+
+      // Check if the value may be freed by operations in any of the blocks
+      // between the definition point (in the outermost region) or the entry
+      // block of the region (in other regions) and the operand or its ancestor
+      // in the region. This includes the entire "form" block if (1) the block
+      // has not been considered as partial above and (2) the block can be
+      // reached again through some control-flow loop. This includes the entire
+      // "to" block if it can be reached form itself through some control-flow
+      // cycle, regardless of whether it has been visited before.
+      Block *ancestorBlock = ancestor->getBlock();
+      Block *from =
+          isOutermost ? defBlock : &ancestorBlock->getParent()->front();
+      if (PotentialDeleters potentialDeleters =
+              isMaybeFreedOnPaths(from, ancestorBlock, operand.get(),
+                                  /*alwaysIncludeFrom=*/!isFromBlockPartial))
+        return potentialDeleters;
+    }
+    return live();
+  }
+
+private:
+  /// Make PotentialDeleters constructors available with shorter names.
+  static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) {
+    return PotentialDeleters::maybeFreed(deleters);
+  }
+  static PotentialDeleters live() { return PotentialDeleters::live(); }
+
+  /// Returns the list of operations that may be deleting the given value betwen
+  /// the first and last operations, non-inclusive. `getNext` indicates the
+  /// direction of the traversal.
+  PotentialDeleters
+  isFreedBetween(Value value, Operation *first, Operation *last,
+                 llvm::function_ref<Operation *(Operation *)> getNext) const {
+    auto it = freedBy.find(value);
+    if (it == freedBy.end())
+      return live();
+    const llvm::SmallPtrSet<Operation *, 2> &deleters = it->getSecond();
+    for (Operation *op = getNext(first); op != last; op = getNext(op)) {
+      if (deleters.contains(op))
+        return maybeFreed(op);
+    }
+    return live();
+  }
+
+  /// Returns the list of operations that may be deleting the given value
+  /// between `root` and `before` values. `root` is expected to be in the same
+  /// block as `before` and precede it. If `before` is null, consider all
+  /// operations until the end of the block including the terminator.
+  PotentialDeleters isFreedInBlockAfter(Operation *root, Value value,
+                                        Operation *before = nullptr) const {
+    return isFreedBetween(value, root, before,
+                          [](Operation *op) { return op->getNextNode(); });
+  }
+
+  /// Returns the list of operations that may be deleting the given value
+  /// between the entry of the block and the `root` operation.
+  PotentialDeleters isFreedInBlockBefore(Operation *root, Value value) const {
+    return isFreedBetween(value, root, nullptr,
+                          [](Operation *op) { return op->getPrevNode(); });
+  }
+
+  /// Returns the list of operations that may be deleting the given value on
+  /// any of the control flow paths between the "form" and the "to" block. The
+  /// operations from any block visited on any control flow path are
+  /// consdiered. The "from" block is considered if there is a control flow
+  /// cycle going through it, i.e., if there is a possibility that all
+  /// operations in this block are visited or if the `alwaysIncludeFrom` flag is
+  /// set. The "to" block is considered only if there is a control flow cycle
+  /// going through it.
+  PotentialDeleters isMaybeFreedOnPaths(Block *from, Block *to, Value value,
+                                        bool alwaysIncludeFrom) {
+    // Find all blocks that lie on any path between "from" and "to", i.e., the
+    // intersection of blocks reachable from "from" and blocks from which "to"
+    // is rechable.
+    const llvm::SmallPtrSet<Block *, 4> &sources = getReachableFrom(to);
+    if (!sources.contains(from))
+      return live();
+
+    llvm::SmallPtrSet<Block *, 4> reachable(getReachable(from));
+    llvm::set_intersect(reachable, sources);
+
+    // If requested, include the "from" block that may not be present in the set
+    // of visited blocks when there is no cycle going through it.
+    if (alwaysIncludeFrom)
+      reachable.insert(from);
+
+    // Join potential deleters from all blocks as we don't know here which of
+    // the paths through the control flow is taken.
+    PotentialDeleters potentialDeleters = live();
+    for (Block *block : reachable) {
+      for (Operation &op : *block) {
+        if (freedBy[value].count(&op))
+          potentialDeleters |= maybeFreed(&op);
+      }
+    }
+    return potentialDeleters;
+  }
+
+  /// Popualtes `reachable` with the set of blocks that are rechable from the
+  /// given block. A block is considered reachable from itself if there is a
+  /// cycle in the control-flow graph that invovles the block.
+  const llvm::SmallPtrSet<Block *, 4> &getReachable(Block *block) {
+    return getReachableImpl(
+        block, [](Block *b) { return b->getSuccessors(); }, reachableCache);
+  }
+
+  /// Populates `sources` with the set of blocks from which the given block is
+  /// reachable.
+  const llvm::SmallPtrSet<Block *, 4> &getReachableFrom(Block *block) {
+    return getReachableImpl(
+        block, [](Block *b) { return b->getPredecessors(); },
+        reachableFromCache);
+  }
+
+  /// Returns true of `instances` contains an effect of `EffectTy` on `value`.
+  template <typename EffectTy>
+  static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances,
+                        Value value) {
+    return llvm::any_of(instances,
+                        [&](const MemoryEffects::EffectInstance &instance) {
+                          return instance.getValue() == value &&
+                                 isa<EffectTy>(instance.getEffect());
+                        });
+  }
+
+  /// Records the values that are being freed by an operation or any of its
+  /// children in `freedBy`.
+  void collectFreedValues(Operation *root) {
+    SmallVector<MemoryEffects::EffectInstance> instances;
+    root->walk([&](Operation *child) {
+      // TODO: extend this to conservatively handle operations with undeclared
+      // side effects as maybe freeing the operands.
+      auto iface = cast<MemoryEffectOpInterface>(child);
+      instances.clear();
+      iface.getEffectsOnResource(transform::TransformMappingResource::get(),
+                                 instances);
+      for (Value operand : child->getOperands()) {
+        if (hasEffect<MemoryEffects::Free>(instances, operand)) {
+          // All parents of the operation that frees a value should be
+          // considered as potentially freeing the value as well.
+          //
+          // TODO: differentiate between must-free/may-free as well as between
+          // this op having the effect and children having the effect. This may
+          // require some analysis of all control flow paths through the nested
+          // regions as well as a mechanism to separate proper side effects from
+          // those obtained by nesting.
+          Operation *parent = child;
+          do {
+            freedBy[operand].insert(parent);
+            if (parent == root)
+              break;
+            parent = parent->getParentOp();
+          } while (true);
+        }
+      }
+    });
+  }
+
+  /// The mapping from a value to operations that have a Free memory effect on
+  /// the TransformMappingResource and associated with this value, or to
+  /// Transform operations transitively containing such operations.
+  DenseMap<Value, llvm::SmallPtrSet<Operation *, 2>> freedBy;
+
+  /// Caches for sets of reachable blocks.
+  DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableCache;
+  DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableFromCache;
+};
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+
+//// A simple pass that warns about any use of a value by a transform operation
+// that may be using the value after it has been freed.
+class CheckUsesPass : public CheckUsesBase<CheckUsesPass> {
+public:
+  void runOnOperation() override {
+    auto &analysis = getAnalysis<TransformOpMemFreeAnalysis>();
+
+    getOperation()->walk([&](Operation *child) {
+      for (OpOperand &operand : child->getOpOperands()) {
+        TransformOpMemFreeAnalysis::PotentialDeleters deleters =
+            analysis.isUseLive(operand);
+        if (!deleters)
+          continue;
+
+        InFlightDiagnostic diag = child->emitWarning()
+                                  << "operand #" << operand.getOperandNumber()
+                                  << " may be used after free";
+        diag.attachNote(operand.get().getLoc()) << "allocated here";
+        for (Operation *d : deleters.getOps()) {
+          diag.attachNote(d->getLoc()) << "freed here";
+        }
+      }
+    });
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace transform {
+std::unique_ptr<Pass> createCheckUsesPass() {
+  return std::make_unique<CheckUsesPass>();
+}
+} // namespace transform
+} // namespace mlir
diff --git a/mlir/test/Dialect/Transform/check-use-after-free.mlir b/mlir/test/Dialect/Transform/check-use-after-free.mlir
new file mode 100644 (file)
index 0000000..6ef865a
--- /dev/null
@@ -0,0 +1,169 @@
+// RUN: mlir-opt %s --transform-dialect-check-uses --split-input-file --verify-diagnostics
+
+func.func @use_after_free_branching_control_flow() {
+  // expected-note @below {{allocated here}}
+  %0 = transform.test_produce_param_or_forward_operand 42
+  transform.test_transform_op_with_regions {
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  },
+  {
+  ^bb0:
+    "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
+  ^bb1:
+    // expected-note @below {{freed here}}
+    transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+  ^bb2:
+    "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+  ^bb3:
+    // expected-warning @below {{operand #0 may be used after free}}
+    transform.sequence %0 {
+    ^bb0(%arg0: !pdl.operation):
+    }
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  }
+  return
+}
+
+// -----
+
+func.func @use_after_free_in_nested_op() {
+  // expected-note @below {{allocated here}}
+  %0 = transform.test_produce_param_or_forward_operand 42
+  // expected-note @below {{freed here}}
+  transform.test_transform_op_with_regions {
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  },
+  {
+  ^bb0:
+    "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
+  ^bb1:
+    transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+  ^bb2:
+    "transform.test_branching_transform_op_terminator"()[^bb3] : () -> ()
+  ^bb3:
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  }
+  // expected-warning @below {{operand #0 may be used after free}}
+  transform.sequence %0 {
+    ^bb0(%arg0: !pdl.operation):
+  }
+  return
+}
+
+// -----
+
+func.func @use_after_free_recursive_side_effects() {
+  transform.sequence {
+  ^bb0(%arg0: !pdl.operation):
+    // expected-note @below {{allocated here}}
+    %0 = transform.sequence %arg0 attributes { ord = 1 } {
+    ^bb1(%arg1: !pdl.operation):
+      yield %arg1 : !pdl.operation
+    } : !pdl.operation
+    transform.sequence %0 attributes { ord = 2 } {
+    ^bb2(%arg2: !pdl.operation):
+    }
+    transform.sequence %0 attributes { ord = 3 } {
+    ^bb3(%arg3: !pdl.operation):
+    }
+    
+    // `transform.sequence` has recursive side effects so it has the same "free"
+    // as the child op it contains.
+    // expected-note @below {{freed here}}
+    transform.sequence %0 attributes { ord = 4 } {
+    ^bb4(%arg4: !pdl.operation):
+      test_consume_operand_if_matches_param_or_fail %0[42]
+    }
+    // expected-warning @below {{operand #0 may be used after free}}
+    transform.sequence %0 attributes { ord = 5 } {
+    ^bb3(%arg3: !pdl.operation):
+    }
+  }
+  return
+}
+
+// -----
+
+func.func @use_after_free() {
+  transform.sequence {
+  ^bb0(%arg0: !pdl.operation):
+    // expected-note @below {{allocated here}}
+    %0 = transform.sequence %arg0 attributes { ord = 1 } {
+    ^bb1(%arg1: !pdl.operation):
+      yield %arg1 : !pdl.operation
+    } : !pdl.operation
+    transform.sequence %0 attributes { ord = 2 } {
+    ^bb2(%arg2: !pdl.operation):
+    }
+    transform.sequence %0 attributes { ord = 3 } {
+    ^bb3(%arg3: !pdl.operation):
+    }
+    
+    // expected-note @below {{freed here}}
+    test_consume_operand_if_matches_param_or_fail %0[42]
+    // expected-warning @below {{operand #0 may be used after free}}
+    transform.sequence %0 attributes { ord = 5 } {
+    ^bb3(%arg3: !pdl.operation):
+    }
+  }
+  return
+}
+
+// -----
+
+// In the case of a control flow cycle, the operation that uses the value may
+// precede the one that frees it in the same block. Both operations should
+// be reported as use-after-free.
+func.func @use_after_free_self_cycle() {
+  // expected-note @below {{allocated here}}
+  %0 = transform.test_produce_param_or_forward_operand 42
+  transform.test_transform_op_with_regions {
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  },
+  {
+  ^bb0:
+    "transform.test_branching_transform_op_terminator"()[^bb1] : () -> ()
+  ^bb1:
+    // expected-warning @below {{operand #0 may be used after free}}
+    transform.sequence %0 {
+    ^bb0(%arg0: !pdl.operation):
+    }
+    // expected-warning @below {{operand #0 may be used after free}}
+    // expected-note @below {{freed here}}
+    transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
+  ^bb2:
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  }
+  return
+}
+
+
+// -----
+
+// Check that the "free" that happens in a cycle is also reported as potential
+// use-after-free.
+func.func @use_after_free_cycle() {
+  // expected-note @below {{allocated here}}
+  %0 = transform.test_produce_param_or_forward_operand 42
+  transform.test_transform_op_with_regions {
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  },
+  {
+  ^bb0:
+    "transform.test_branching_transform_op_terminator"()[^bb1, ^bb2] : () -> ()
+  ^bb1:
+    // expected-warning @below {{operand #0 may be used after free}}
+    // expected-note @below {{freed here}}
+    transform.test_consume_operand_if_matches_param_or_fail %0[42]
+    "transform.test_branching_transform_op_terminator"()[^bb2, ^bb3] : () -> ()
+  ^bb2:
+    "transform.test_branching_transform_op_terminator"()[^bb1] : () -> ()
+  ^bb3:
+    "transform.test_branching_transform_op_terminator"() : () -> ()
+  }
+  return
+}
+
index c9687fa..39a3afc 100644 (file)
@@ -184,6 +184,21 @@ LogicalResult mlir::test::TestRemoveTestExtensionOp::apply(
   state.removeExtension<TestTransformStateExtension>();
   return success();
 }
+LogicalResult mlir::test::TestTransformOpWithRegions::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  return success();
+}
+
+void mlir::test::TestTransformOpWithRegions::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
+
+LogicalResult mlir::test::TestBranchingTransformOpTerminator::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  return success();
+}
+
+void mlir::test::TestBranchingTransformOpTerminator::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
 
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
index d33f790..8623b8a 100644 (file)
@@ -92,5 +92,21 @@ def TestRemoveTestExtensionOp
   let cppNamespace = "::mlir::test";
 }
 
+def TestTransformOpWithRegions
+  : Op<Transform_Dialect, "test_transform_op_with_regions",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let regions = (region AnyRegion:$first, AnyRegion:$second);
+  let assemblyFormat = "attr-dict-with-keyword regions";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestBranchingTransformOpTerminator
+  : Op<Transform_Dialect, "test_branching_transform_op_terminator",
+       [Terminator, DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let successors = (successor VariadicSuccessor<AnySuccessor>:$succ);
+  let cppNamespace = "::mlir::test";
+}
 
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
index 00fec62..61171f3 100644 (file)
@@ -6266,6 +6266,7 @@ cc_library(
         ":TosaDialect",
         ":TosaToLinalg",
         ":TransformDialect",
+        ":TransformDialectTransforms",
         ":Transforms",
         ":TransformsPassIncGen",
         ":VectorOps",
@@ -7876,6 +7877,7 @@ td_library(
     name = "TransformDialectTdFiles",
     srcs = glob(["include/mlir/Dialect/Transform/IR/*.td"]),
     deps = [
+        ":ControlFlowInterfacesTdFiles",
         ":OpBaseTdFiles",
         ":PDLDialectTdFiles",
         ":SideEffectInterfacesTdFiles",
@@ -7949,6 +7951,7 @@ cc_library(
     srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]),
     hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]),
     deps = [
+        ":ControlFlowInterfaces",
         ":IR",
         ":PDLDialect",
         ":PDLInterpDialect",
@@ -7963,6 +7966,47 @@ cc_library(
 )
 
 td_library(
+    name = "TransformDialectTransformsTdFiles",
+    srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]),
+    deps = [
+        ":PassBaseTdFiles",
+        ":TransformDialectTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "TransformDialectTransformsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            [
+                "-gen-pass-decls",
+                "-name=Transform",
+            ],
+            "include/mlir/Dialect/Transform/Transforms/Passes.h.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Transform/Transforms/Passes.td",
+    deps = [":TransformDialectTransformsTdFiles"],
+)
+
+cc_library(
+    name = "TransformDialectTransforms",
+    srcs = glob(["lib/Dialect/Transform/Transforms/*.cpp"]),
+    hdrs = glob(["include/mlir/Dialect/Transform/Transforms/*.h"]),
+    deps = [
+        ":Analysis",
+        ":IR",
+        ":Pass",
+        ":SideEffectInterfaces",
+        ":TransformDialect",
+        ":TransformDialectTransformsIncGen",
+        "//llvm:Support",
+    ],
+)
+
+td_library(
     name = "ComplexOpsTdFiles",
     srcs = [
         "include/mlir/Dialect/Complex/IR/ComplexBase.td",