[Attributor] Derive AAFunctionReachability attribute.
authorKuter Dinel <kuterdinel@gmail.com>
Sat, 19 Jun 2021 20:50:11 +0000 (23:50 +0300)
committerKuter Dinel <kuterdinel@gmail.com>
Wed, 23 Jun 2021 17:43:10 +0000 (20:43 +0300)
This attribute uses Attributor's internal 'optimistic' call graph
information to answer queries about function call reachability.

Functions can become reachable over time as new call edges are
discovered.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D104599

llvm/include/llvm/Transforms/IPO/Attributor.h
llvm/lib/Transforms/IPO/AttributorAttributes.cpp
llvm/unittests/Transforms/IPO/AttributorTest.cpp

index 3109dad..5449003 100644 (file)
@@ -4199,6 +4199,39 @@ struct AAExecutionDomain
   static const char ID;
 };
 
+/// An abstract Attribute for computing reachability between functions.
+struct AAFunctionReachability
+    : public StateWrapper<BooleanState, AbstractAttribute> {
+  using Base = StateWrapper<BooleanState, AbstractAttribute>;
+
+  AAFunctionReachability(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// If the function represented by this possition can reach \p Fn.
+  virtual bool canReach(Attributor &A, Function *Fn) const = 0;
+
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAFunctionReachability &createForPosition(const IRPosition &IRP,
+                                                   Attributor &A);
+
+  /// See AbstractAttribute::getName()
+  const std::string getName() const override { return "AAFuncitonReacability"; }
+
+  /// See AbstractAttribute::getIdAddr()
+  const char *getIdAddr() const override { return &ID; }
+
+  /// This function should return true if the type of the \p AA is AACallEdges.
+  static bool classof(const AbstractAttribute *AA) {
+    return (AA->getIdAddr() == &ID);
+  }
+
+  /// Unique ID (due to the unique address)
+  static const char ID;
+
+private:
+  /// Can this function reach a call with unknown calee.
+  virtual bool canReachUnknownCallee() const = 0;
+};
+
 /// Run options, used by the pass manager.
 enum AttributorRunOption {
   NONE = 0,
index b5221b1..d6e9e19 100644 (file)
@@ -136,6 +136,7 @@ PIPE_OPERATOR(AAUndefinedBehavior)
 PIPE_OPERATOR(AAPotentialValues)
 PIPE_OPERATOR(AANoUndef)
 PIPE_OPERATOR(AACallEdges)
+PIPE_OPERATOR(AAFunctionReachability)
 
 #undef PIPE_OPERATOR
 } // namespace llvm
@@ -8276,6 +8277,118 @@ struct AACallEdgesFunction : public AACallEdges {
   bool HasUnknownCallee = false;
 };
 
+struct AAFunctionReachabilityFunction : public AAFunctionReachability {
+  AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
+      : AAFunctionReachability(IRP, A) {}
+
+  bool canReach(Attributor &A, Function *Fn) const override {
+    // Assume that we can reach any function if we can reach a call with
+    // unknown callee.
+    if (CanReachUnknownCallee)
+      return true;
+
+    if (ReachableQueries.count(Fn))
+      return true;
+
+    if (UnreachableQueries.count(Fn))
+      return false;
+
+    const AACallEdges &AAEdges =
+        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
+
+    const SetVector<Function *> &Edges = AAEdges.getOptimisticEdges();
+    bool Result = checkIfReachable(A, Edges, Fn);
+
+    // Attributor returns attributes as const, so this function has to be
+    // const for users of this attribute to use it without having to do
+    // a const_cast.
+    // This is a hack for us to be able to cache queries.
+    auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+
+    if (Result)
+      NonConstThis->ReachableQueries.insert(Fn);
+    else
+      NonConstThis->UnreachableQueries.insert(Fn);
+
+    return Result;
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    if (CanReachUnknownCallee)
+      return ChangeStatus::UNCHANGED;
+
+    const AACallEdges &AAEdges =
+        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
+    const SetVector<Function *> &Edges = AAEdges.getOptimisticEdges();
+    ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+    if (AAEdges.hasUnknownCallee()) {
+      bool OldCanReachUnknown = CanReachUnknownCallee;
+      CanReachUnknownCallee = true;
+      return OldCanReachUnknown ? ChangeStatus::UNCHANGED
+                                : ChangeStatus::CHANGED;
+    }
+
+    // Check if any of the unreachable functions become reachable.
+    for (auto Current = UnreachableQueries.begin();
+         Current != UnreachableQueries.end();) {
+      if (!checkIfReachable(A, Edges, *Current)) {
+        Current++;
+        continue;
+      }
+      ReachableQueries.insert(*Current);
+      UnreachableQueries.erase(*Current++);
+      Change = ChangeStatus::CHANGED;
+    }
+
+    return Change;
+  }
+
+  const std::string getAsStr() const override {
+    size_t QueryCount = ReachableQueries.size() + UnreachableQueries.size();
+
+    return "FunctionReachability [" + std::to_string(ReachableQueries.size()) +
+           "," + std::to_string(QueryCount) + "]";
+  }
+
+  void trackStatistics() const override {}
+
+private:
+  bool canReachUnknownCallee() const override { return CanReachUnknownCallee; }
+
+  bool checkIfReachable(Attributor &A, const SetVector<Function *> &Edges,
+                        Function *Fn) const {
+    if (Edges.count(Fn))
+      return true;
+
+    for (Function *Edge : Edges) {
+      // We don't need a dependency if the result is reachable.
+      const AAFunctionReachability &EdgeReachability =
+          A.getAAFor<AAFunctionReachability>(*this, IRPosition::function(*Edge),
+                                             DepClassTy::NONE);
+
+      if (EdgeReachability.canReach(A, Fn))
+        return true;
+    }
+    for (Function *Fn : Edges)
+      A.getAAFor<AAFunctionReachability>(*this, IRPosition::function(*Fn),
+                                         DepClassTy::REQUIRED);
+
+    return false;
+  }
+
+  /// Set of functions that we know for sure is reachable.
+  SmallPtrSet<Function *, 8> ReachableQueries;
+
+  /// Set of functions that are unreachable, but might become reachable.
+  SmallPtrSet<Function *, 8> UnreachableQueries;
+
+  /// If we can reach a function with a call to a unknown function we assume
+  /// that we can reach any function.
+  bool CanReachUnknownCallee = false;
+};
+
 } // namespace
 
 AACallGraphNode *AACallEdgeIterator::operator*() const {
@@ -8311,6 +8424,7 @@ const char AAValueConstantRange::ID = 0;
 const char AAPotentialValues::ID = 0;
 const char AANoUndef::ID = 0;
 const char AACallEdges::ID = 0;
+const char AAFunctionReachability::ID = 0;
 
 // Macro magic to create the static generator function for attributes that
 // follow the naming scheme.
@@ -8431,6 +8545,7 @@ CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack)
 CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReachability)
 CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior)
 CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges)
+CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAFunctionReachability)
 
 CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior)
 
index 458b68b..8bbea56 100644 (file)
@@ -73,4 +73,71 @@ TEST_F(AttributorTestBase, TestCast) {
   ASSERT_TRUE(SSucc);
 }
 
+TEST_F(AttributorTestBase, AAReachabilityTest) {
+  const char *ModuleString = R"(
+    declare void @func4()
+    declare void @func3()
+
+    define void @func2() {
+    entry:
+      call void @func3()
+      ret void
+    }
+
+    define void @func1() {
+    entry:
+      call void @func2()
+      ret void
+    }
+
+    define void @func5(void ()* %unknown) {
+    entry:
+      call void %unknown()
+      ret void
+    }
+
+    define void @func6() {
+    entry:
+      call void @func5(void ()* @func3)
+      ret void
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+
+  SetVector<Function *> Functions;
+  AnalysisGetter AG;
+  for (Function &F : M)
+    Functions.insert(&F);
+
+  CallGraphUpdater CGUpdater;
+  BumpPtrAllocator Allocator;
+  InformationCache InfoCache(M, AG, Allocator, nullptr);
+  Attributor A(Functions, InfoCache, CGUpdater);
+
+  Function *F1 = M.getFunction("func1");
+  Function *F3 = M.getFunction("func3");
+  Function *F4 = M.getFunction("func4");
+  Function *F6 = M.getFunction("func6");
+
+  const AAFunctionReachability &F1AA =
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F1));
+
+  const AAFunctionReachability &F6AA =
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F6));
+
+  F1AA.canReach(A, F3);
+  F1AA.canReach(A, F4);
+  F6AA.canReach(A, F4);
+
+  A.run();
+
+  ASSERT_TRUE(F1AA.canReach(A, F3));
+  ASSERT_FALSE(F1AA.canReach(A, F4));
+
+  // Assumed to be reacahable, since F6 can reach a function with
+  // a unknown callee.
+  ASSERT_TRUE(F6AA.canReach(A, F4));
+}
+
 } // namespace llvm