[mlir] Add support for querying the ModRef behavior from the AliasAnalysis class
authorRiver Riddle <riddleriver@gmail.com>
Thu, 27 May 2021 20:47:52 +0000 (13:47 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 27 May 2021 20:57:29 +0000 (13:57 -0700)
This allows for checking if a given operation may modify/reference/or both a given value. Right now this API is limited to Value based memory locations, but we should expand this to include attribute based values at some point. This is left for future work because the rest of the AliasAnalysis API also has this restriction.

Differential Revision: https://reviews.llvm.org/D101673

mlir/include/mlir/Analysis/AliasAnalysis.h
mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h
mlir/lib/Analysis/AliasAnalysis.cpp
mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
mlir/test/Analysis/test-alias-analysis-modref.mlir [new file with mode: 0644]
mlir/test/lib/Analysis/TestAliasAnalysis.cpp

index f3fce42..925af24 100644 (file)
@@ -67,14 +67,106 @@ public:
   /// Returns if this result is a partial alias.
   bool isPartial() const { return kind == PartialAlias; }
 
-  /// Return the internal kind of this alias result.
-  Kind getKind() const { return kind; }
+  /// Print this alias result to the provided output stream.
+  void print(raw_ostream &os) const;
 
 private:
   /// The internal kind of the result.
   Kind kind;
 };
 
+inline raw_ostream &operator<<(raw_ostream &os, const AliasResult &result) {
+  result.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// ModRefResult
+//===----------------------------------------------------------------------===//
+
+/// The possible results of whether a memory access modifies or references
+/// a memory location. The possible results are: no access at all, a
+/// modification, a reference, or both a modification and a reference.
+class LLVM_NODISCARD ModRefResult {
+  /// Note: This is a simplified version of the ModRefResult in
+  /// `llvm/Analysis/AliasAnalysis.h`, and namely removes the `Must` concept. If
+  /// this becomes useful/necessary we should add it here.
+  enum class Kind {
+    /// The access neither references nor modifies the value stored in memory.
+    NoModRef = 0,
+    /// The access may reference the value stored in memory.
+    Ref = 1,
+    /// The access may modify the value stored in memory.
+    Mod = 2,
+    /// The access may reference and may modify the value stored in memory.
+    ModRef = Ref | Mod,
+  };
+
+public:
+  bool operator==(const ModRefResult &rhs) const { return kind == rhs.kind; }
+  bool operator!=(const ModRefResult &rhs) const { return !(*this == rhs); }
+
+  /// Return a new result that indicates that the memory access neither
+  /// references nor modifies the value stored in memory.
+  static ModRefResult getNoModRef() { return Kind::NoModRef; }
+
+  /// Return a new result that indicates that the memory access may reference
+  /// the value stored in memory.
+  static ModRefResult getRef() { return Kind::Ref; }
+
+  /// Return a new result that indicates that the memory access may modify the
+  /// value stored in memory.
+  static ModRefResult getMod() { return Kind::Mod; }
+
+  /// Return a new result that indicates that the memory access may reference
+  /// and may modify the value stored in memory.
+  static ModRefResult getModAndRef() { return Kind::ModRef; }
+
+  /// Returns if this result does not modify or reference memory.
+  LLVM_NODISCARD bool isNoModRef() const { return kind == Kind::NoModRef; }
+
+  /// Returns if this result modifies memory.
+  LLVM_NODISCARD bool isMod() const {
+    return static_cast<int>(kind) & static_cast<int>(Kind::Mod);
+  }
+
+  /// Returns if this result references memory.
+  LLVM_NODISCARD bool isRef() const {
+    return static_cast<int>(kind) & static_cast<int>(Kind::Ref);
+  }
+
+  /// Returns if this result modifies *or* references memory.
+  LLVM_NODISCARD bool isModOrRef() const { return kind != Kind::NoModRef; }
+
+  /// Returns if this result modifies *and* references memory.
+  LLVM_NODISCARD bool isModAndRef() const { return kind == Kind::ModRef; }
+
+  /// Merge this ModRef result with `other` and return the result.
+  ModRefResult merge(const ModRefResult &other) {
+    return ModRefResult(static_cast<Kind>(static_cast<int>(kind) |
+                                          static_cast<int>(other.kind)));
+  }
+  /// Intersect this ModRef result with `other` and return the result.
+  ModRefResult intersect(const ModRefResult &other) {
+    return ModRefResult(static_cast<Kind>(static_cast<int>(kind) &
+                                          static_cast<int>(other.kind)));
+  }
+
+  /// Print this ModRef result to the provided output stream.
+  void print(raw_ostream &os) const;
+
+private:
+  ModRefResult(Kind kind) : kind(kind) {}
+
+  /// The internal kind of the result.
+  Kind kind;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, const ModRefResult &result) {
+  result.print(os);
+  return os;
+}
+
 //===----------------------------------------------------------------------===//
 // AliasAnalysisTraits
 //===----------------------------------------------------------------------===//
@@ -92,6 +184,9 @@ struct AliasAnalysisTraits {
 
     /// Given two values, return their aliasing behavior.
     virtual AliasResult alias(Value lhs, Value rhs) = 0;
+
+    /// Return the modify-reference behavior of `op` on `location`.
+    virtual ModRefResult getModRef(Operation *op, Value location) = 0;
   };
 
   /// This class represents the `Model` of an alias analysis implementation
@@ -108,6 +203,11 @@ struct AliasAnalysisTraits {
       return impl.alias(lhs, rhs);
     }
 
+    /// Return the modify-reference behavior of `op` on `location`.
+    ModRefResult getModRef(Operation *op, Value location) final {
+      return impl.getModRef(op, location);
+    }
+
   private:
     ImplT impl;
   };
@@ -147,7 +247,12 @@ public:
   ///   * AnalysisT(AnalysisT &&)
   ///   * AliasResult alias(Value lhs, Value rhs)
   ///     - This method returns an `AliasResult` that corresponds to the
-  ///       aliasing behavior between `lhs` and `rhs`.
+  ///       aliasing behavior between `lhs` and `rhs`. The conservative "I don't
+  ///       know" result of this method should be MayAlias.
+  ///   * ModRefResult getModRef(Operation *op, Value location)
+  ///     - This method returns a `ModRefResult` that corresponds to the
+  ///       modify-reference behavior of `op` on the given `location`. The
+  ///       conservative "I don't know" result of this method should be ModRef.
   template <typename AnalysisT>
   void addAnalysisImplementation(AnalysisT &&analysis) {
     aliasImpls.push_back(
@@ -161,6 +266,13 @@ public:
   /// Given two values, return their aliasing behavior.
   AliasResult alias(Value lhs, Value rhs);
 
+  //===--------------------------------------------------------------------===//
+  // ModRef Queries
+  //===--------------------------------------------------------------------===//
+
+  /// Return the modify-reference behavior of `op` on `location`.
+  ModRefResult getModRef(Operation *op, Value location);
+
 private:
   /// A set of internal alias analysis implementations.
   SmallVector<std::unique_ptr<Concept>, 4> aliasImpls;
index 45edd20..afed185 100644 (file)
@@ -25,6 +25,9 @@ class LocalAliasAnalysis {
 public:
   /// Given two values, return their aliasing behavior.
   AliasResult alias(Value lhs, Value rhs);
+
+  /// Return the modify-reference behavior of `op` on `location`.
+  ModRefResult getModRef(Operation *op, Value location);
 };
 } // end namespace mlir
 
index 946825e..2f2b782 100644 (file)
@@ -27,6 +27,44 @@ AliasResult AliasResult::merge(AliasResult other) const {
   return MayAlias;
 }
 
+void AliasResult::print(raw_ostream &os) const {
+  switch (kind) {
+  case Kind::NoAlias:
+    os << "NoAlias";
+    break;
+  case Kind::MayAlias:
+    os << "MayAlias";
+    break;
+  case Kind::PartialAlias:
+    os << "PartialAlias";
+    break;
+  case Kind::MustAlias:
+    os << "MustAlias";
+    break;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// ModRefResult
+//===----------------------------------------------------------------------===//
+
+void ModRefResult::print(raw_ostream &os) const {
+  switch (kind) {
+  case Kind::NoModRef:
+    os << "NoModRef";
+    break;
+  case Kind::Ref:
+    os << "Ref";
+    break;
+  case Kind::Mod:
+    os << "Mod";
+    break;
+  case Kind::ModRef:
+    os << "ModRef";
+    break;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // AliasAnalysis
 //===----------------------------------------------------------------------===//
@@ -35,7 +73,6 @@ AliasAnalysis::AliasAnalysis(Operation *op) {
   addAnalysisImplementation(LocalAliasAnalysis());
 }
 
-/// Given the two values, return their aliasing behavior.
 AliasResult AliasAnalysis::alias(Value lhs, Value rhs) {
   // Check each of the alias analysis implemenations for an alias result.
   for (const std::unique_ptr<Concept> &aliasImpl : aliasImpls) {
@@ -45,3 +82,16 @@ AliasResult AliasAnalysis::alias(Value lhs, Value rhs) {
   }
   return AliasResult::MayAlias;
 }
+
+ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) {
+  // Compute the mod-ref behavior by refining a top `ModRef` result with each of
+  // the alias analysis implementations. We early exit at the point where we
+  // refine down to a `NoModRef`.
+  ModRefResult result = ModRefResult::getModAndRef();
+  for (const std::unique_ptr<Concept> &aliasImpl : aliasImpls) {
+    result = result.intersect(aliasImpl->getModRef(op, location));
+    if (result.isNoModRef())
+      return result;
+  }
+  return result;
+}
index 17a9ded..062443a 100644 (file)
@@ -195,7 +195,7 @@ static void collectUnderlyingAddressValues(Value value,
 }
 
 //===----------------------------------------------------------------------===//
-// LocalAliasAnalysis
+// LocalAliasAnalysis: alias
 //===----------------------------------------------------------------------===//
 
 /// Given a value, try to get an allocation effect attached to it. If
@@ -336,3 +336,56 @@ AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) {
   // We should always have a valid result here.
   return *result;
 }
+
+//===----------------------------------------------------------------------===//
+// LocalAliasAnalysis: getModRef
+//===----------------------------------------------------------------------===//
+
+ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) {
+  // Check to see if this operation relies on nested side effects.
+  if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
+    // TODO: To check recursive operations we need to check all of the nested
+    // operations, which can result in a quadratic number of queries. We should
+    // introduce some caching of some kind to help alleviate this, especially as
+    // this caching could be used in other areas of the codebase (e.g. when
+    // checking `wouldOpBeTriviallyDead`).
+    return ModRefResult::getModAndRef();
+  }
+
+  // Otherwise, check to see if this operation has a memory effect interface.
+  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!interface)
+    return ModRefResult::getModAndRef();
+
+  // Build a ModRefResult by merging the behavior of the effects of this
+  // operation.
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  interface.getEffects(effects);
+
+  ModRefResult result = ModRefResult::getNoModRef();
+  for (const MemoryEffects::EffectInstance &effect : effects) {
+    if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect()))
+      continue;
+
+    // Check for an alias between the effect and our memory location.
+    // TODO: Add support for checking an alias with a symbol reference.
+    AliasResult aliasResult = AliasResult::MayAlias;
+    if (Value effectValue = effect.getValue())
+      aliasResult = alias(effectValue, location);
+
+    // If we don't alias, ignore this effect.
+    if (aliasResult.isNo())
+      continue;
+
+    // Merge in the corresponding mod or ref for this effect.
+    if (isa<MemoryEffects::Read>(effect.getEffect())) {
+      result = result.merge(ModRefResult::getRef());
+    } else {
+      assert(isa<MemoryEffects::Write>(effect.getEffect()));
+      result = result.merge(ModRefResult::getMod());
+    }
+    if (result.isModAndRef())
+      break;
+  }
+  return result;
+}
diff --git a/mlir/test/Analysis/test-alias-analysis-modref.mlir b/mlir/test/Analysis/test-alias-analysis-modref.mlir
new file mode 100644 (file)
index 0000000..46ac7fb
--- /dev/null
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -pass-pipeline='func(test-alias-analysis-modref)' -split-input-file -allow-unregistered-dialect 2>&1 | FileCheck %s
+
+// CHECK-LABEL: Testing : "no_side_effects"
+// CHECK: alloc -> func.region0#0: NoModRef
+// CHECK: dealloc -> func.region0#0: NoModRef
+// CHECK: return -> func.region0#0: NoModRef
+func @no_side_effects(%arg: memref<2xf32>) attributes {test.ptr = "func"} {
+  %1 = memref.alloc() {test.ptr = "alloc"} : memref<8x64xf32>
+  memref.dealloc %1 {test.ptr = "dealloc"} : memref<8x64xf32>
+  return {test.ptr = "return"}
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "simple"
+// CHECK-DAG: store -> alloc#0: Mod
+// CHECK-DAG: load -> alloc#0: Ref
+
+// CHECK-DAG: store -> func.region0#0: NoModRef
+// CHECK-DAG: load -> func.region0#0: NoModRef
+func @simple(%arg: memref<i32>, %value: i32) attributes {test.ptr = "func"} {
+  %1 = memref.alloca() {test.ptr = "alloc"} : memref<i32>
+  memref.store %value, %1[] {test.ptr = "store"} : memref<i32>
+  %2 = memref.load %1[] {test.ptr = "load"} : memref<i32>
+  return {test.ptr = "return"}
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "mayalias"
+// CHECK-DAG: store -> func.region0#0: Mod
+// CHECK-DAG: load -> func.region0#0: Ref
+
+// CHECK-DAG: store -> func.region0#1: Mod
+// CHECK-DAG: load -> func.region0#1: Ref
+func @mayalias(%arg0: memref<i32>, %arg1: memref<i32>, %value: i32) attributes {test.ptr = "func"} {
+  memref.store %value, %arg1[] {test.ptr = "store"} : memref<i32>
+  %1 = memref.load %arg1[] {test.ptr = "load"} : memref<i32>
+  return {test.ptr = "return"}
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "recursive"
+// CHECK-DAG: if -> func.region0#0: ModRef
+// CHECK-DAG: if -> func.region0#1: ModRef
+
+// TODO: This is provably NoModRef, but requires handling recursive side
+// effects.
+// CHECK-DAG: if -> alloc#0: ModRef
+func @recursive(%arg0: memref<i32>, %arg1: memref<i32>, %cond: i1, %value: i32) attributes {test.ptr = "func"} {
+  %0 = memref.alloca() {test.ptr = "alloc"} : memref<i32>
+  scf.if %cond {
+    memref.store %value, %arg0[] : memref<i32>
+    %1 = memref.load %arg0[] : memref<i32>
+  } {test.ptr = "if"}
+  return {test.ptr = "return"}
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "unknown"
+// CHECK-DAG: unknown -> func.region0#0: ModRef
+func @unknown(%arg0: memref<i32>) attributes {test.ptr = "func"} {
+  "foo.op"() {test.ptr = "unknown"} : () -> ()
+  return
+}
index d17a1c1..c54e5d8 100644 (file)
 
 using namespace mlir;
 
+/// Print a value that is used as an operand of an alias query.
+static void printAliasOperand(Operation *op) {
+  llvm::errs() << op->getAttrOfType<StringAttr>("test.ptr").getValue();
+}
+static void printAliasOperand(Value value) {
+  if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
+    Region *region = arg.getParentRegion();
+    unsigned parentBlockNumber =
+        std::distance(region->begin(), arg.getOwner()->getIterator());
+    llvm::errs() << region->getParentOp()
+                        ->getAttrOfType<StringAttr>("test.ptr")
+                        .getValue()
+                 << ".region" << region->getRegionNumber();
+    if (parentBlockNumber != 0)
+      llvm::errs() << ".block" << parentBlockNumber;
+    llvm::errs() << "#" << arg.getArgNumber();
+    return;
+  }
+  OpResult result = value.cast<OpResult>();
+  printAliasOperand(result.getOwner());
+  llvm::errs() << "#" << result.getResultNumber();
+}
+
+//===----------------------------------------------------------------------===//
+// Testing AliasResult
+//===----------------------------------------------------------------------===//
+
 namespace {
 struct TestAliasAnalysisPass
     : public PassWrapper<TestAliasAnalysisPass, OperationPass<>> {
   void runOnOperation() override {
-    llvm::errs() << "Testing : ";
-    if (Attribute testName = getOperation()->getAttr("test.name"))
-      llvm::errs() << testName << "\n";
-    else
-      llvm::errs() << getOperation()->getAttr("sym_name") << "\n";
+    llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
 
     // Collect all of the values to check for aliasing behavior.
     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
@@ -49,52 +72,64 @@ struct TestAliasAnalysisPass
     printAliasOperand(lhs);
     llvm::errs() << " <-> ";
     printAliasOperand(rhs);
-    llvm::errs() << ": ";
+    llvm::errs() << ": " << result << "\n";
+  }
+};
+} // end anonymous namespace
 
-    switch (result.getKind()) {
-    case AliasResult::NoAlias:
-      llvm::errs() << "NoAlias";
-      break;
-    case AliasResult::MayAlias:
-      llvm::errs() << "MayAlias";
-      break;
-    case AliasResult::PartialAlias:
-      llvm::errs() << "PartialAlias";
-      break;
-    case AliasResult::MustAlias:
-      llvm::errs() << "MustAlias";
-      break;
+//===----------------------------------------------------------------------===//
+// Testing ModRefResult
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestAliasAnalysisModRefPass
+    : public PassWrapper<TestAliasAnalysisModRefPass, OperationPass<>> {
+  void runOnOperation() override {
+    llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
+
+    // Collect all of the values to check for aliasing behavior.
+    AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+    SmallVector<Value, 32> valsToCheck;
+    getOperation()->walk([&](Operation *op) {
+      if (!op->getAttr("test.ptr"))
+        return;
+      valsToCheck.append(op->result_begin(), op->result_end());
+      for (Region &region : op->getRegions())
+        for (Block &block : region)
+          valsToCheck.append(block.args_begin(), block.args_end());
+    });
+
+    // Check for aliasing behavior between each of the values.
+    for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it) {
+      getOperation()->walk([&](Operation *op) {
+        if (!op->getAttr("test.ptr"))
+          return;
+        printModRefResult(aliasAnalysis.getModRef(op, *it), op, *it);
+      });
     }
-    llvm::errs() << "\n";
   }
-  /// Print a value that is used as an operand of an alias query.
-  void printAliasOperand(Value value) {
-    if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
-      Region *region = arg.getParentRegion();
-      unsigned parentBlockNumber =
-          std::distance(region->begin(), arg.getOwner()->getIterator());
-      llvm::errs() << region->getParentOp()
-                          ->getAttrOfType<StringAttr>("test.ptr")
-                          .getValue()
-                   << ".region" << region->getRegionNumber();
-      if (parentBlockNumber != 0)
-        llvm::errs() << ".block" << parentBlockNumber;
-      llvm::errs() << "#" << arg.getArgNumber();
-      return;
-    }
-    OpResult result = value.cast<OpResult>();
-    llvm::errs()
-        << result.getOwner()->getAttrOfType<StringAttr>("test.ptr").getValue()
-        << "#" << result.getResultNumber();
+
+  /// Print the result of an alias query.
+  void printModRefResult(ModRefResult result, Operation *op, Value location) {
+    printAliasOperand(op);
+    llvm::errs() << " -> ";
+    printAliasOperand(location);
+    llvm::errs() << ": " << result << "\n";
   }
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// Pass Registration
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 namespace test {
 void registerTestAliasAnalysisPass() {
-  PassRegistration<TestAliasAnalysisPass> pass("test-alias-analysis",
-                                               "Test alias analysis results.");
+  PassRegistration<TestAliasAnalysisPass> aliasPass(
+      "test-alias-analysis", "Test alias analysis results.");
+  PassRegistration<TestAliasAnalysisModRefPass> modRefPass(
+      "test-alias-analysis-modref", "Test alias analysis ModRef results.");
 }
 } // namespace test
 } // namespace mlir