[Attributor] Set up a dedicated simplification call back map for `GlobalVariable`
authorShilei Tian <i@tianshilei.me>
Fri, 21 Apr 2023 04:08:24 +0000 (00:08 -0400)
committerShilei Tian <i@tianshilei.me>
Fri, 21 Apr 2023 04:08:35 +0000 (00:08 -0400)
Currently we don't check call backs for global variable simplification.
What's more, the only way that we can register a simplification call back for
global variable is through its initializer (essentially a `Constant *`). It might
not correspond to the right global variable. In this patch, we set up a dedicated
simplification map for `GlobalVariable`.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D144749

llvm/include/llvm/Transforms/IPO/Attributor.h
llvm/lib/Transforms/IPO/Attributor.cpp

index 3b09635..d4a2ed0 100644 (file)
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/DOTGraphTraits.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/TimeProfiler.h"
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
@@ -306,7 +307,7 @@ inline bool operator==(const RangeTy &A, const RangeTy &B) {
 inline bool operator!=(const RangeTy &A, const RangeTy &B) { return !(A == B); }
 
 /// Return the initial value of \p Obj with type \p Ty if that is a constant.
-Constant *getInitialValueForObj(Value &Obj, Type &Ty,
+Constant *getInitialValueForObj(Attributor &A, Value &Obj, Type &Ty,
                                 const TargetLibraryInfo *TLI,
                                 const DataLayout &DL,
                                 RangeTy *RangePtr = nullptr);
@@ -1895,6 +1896,40 @@ struct Attributor {
     return SimplificationCallbacks.count(IRP);
   }
 
+  /// Register \p CB as a simplification callback.
+  /// Similar to \p registerSimplificationCallback, the call back will be called
+  /// first when we simplify a global variable \p GV.
+  using GlobalVariableSimplifictionCallbackTy =
+      std::function<std::optional<Constant *>(
+          const GlobalVariable &, const AbstractAttribute *, bool &)>;
+  void registerGlobalVariableSimplificationCallback(
+      const GlobalVariable &GV,
+      const GlobalVariableSimplifictionCallbackTy &CB) {
+    GlobalVariableSimplificationCallbacks[&GV].emplace_back(CB);
+  }
+
+  /// Return true if there is a simplification callback for \p GV.
+  bool hasGlobalVariableSimplificationCallback(const GlobalVariable &GV) {
+    return GlobalVariableSimplificationCallbacks.count(&GV);
+  }
+
+  /// Return \p std::nullopt if there is no call back registered for \p GV or
+  /// the call back is still not sure if \p GV can be simplified. Return \p
+  /// nullptr if \p GV can't be simplified.
+  std::optional<Constant *>
+  getAssumedInitializerFromCallBack(const GlobalVariable &GV,
+                                    const AbstractAttribute *AA,
+                                    bool &UsedAssumedInformation) {
+    assert(GlobalVariableSimplificationCallbacks.contains(&GV));
+    for (auto &CB : GlobalVariableSimplificationCallbacks.lookup(&GV)) {
+      auto SimplifiedGV = CB(GV, AA, UsedAssumedInformation);
+      // For now we assume the call back will not return a std::nullopt.
+      assert(SimplifiedGV.has_value() && "SimplifiedGV has not value");
+      return *SimplifiedGV;
+    }
+    llvm_unreachable("there must be a callback registered");
+  }
+
   using VirtualUseCallbackTy =
       std::function<bool(Attributor &, const AbstractAttribute *)>;
   void registerVirtualUseCallback(const Value &V,
@@ -1907,6 +1942,12 @@ private:
   DenseMap<IRPosition, SmallVector<SimplifictionCallbackTy, 1>>
       SimplificationCallbacks;
 
+  /// The vector with all simplification callbacks for global variables
+  /// registered by outside AAs.
+  DenseMap<const GlobalVariable *,
+           SmallVector<GlobalVariableSimplifictionCallbackTy, 1>>
+      GlobalVariableSimplificationCallbacks;
+
   DenseMap<const Value *, SmallVector<VirtualUseCallbackTy, 1>>
       VirtualUseCallbacks;
 
index 824587f..4bb1aad 100644 (file)
@@ -221,7 +221,7 @@ bool AA::isDynamicallyUnique(Attributor &A, const AbstractAttribute &QueryingAA,
   return InstanceInfoAA.isAssumedUniqueForAnalysis();
 }
 
-Constant *AA::getInitialValueForObj(Value &Obj, Type &Ty,
+Constant *AA::getInitialValueForObj(Attributor &A, Value &Obj, Type &Ty,
                                     const TargetLibraryInfo *TLI,
                                     const DataLayout &DL,
                                     AA::RangeTy *RangePtr) {
@@ -232,17 +232,31 @@ Constant *AA::getInitialValueForObj(Value &Obj, Type &Ty,
   auto *GV = dyn_cast<GlobalVariable>(&Obj);
   if (!GV)
     return nullptr;
-  if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer()))
-    return nullptr;
-  if (!GV->hasInitializer())
-    return UndefValue::get(&Ty);
+
+  bool UsedAssumedInformation = false;
+  Constant *Initializer = nullptr;
+  if (A.hasGlobalVariableSimplificationCallback(*GV)) {
+    auto AssumedGV = A.getAssumedInitializerFromCallBack(
+        *GV, /* const AbstractAttribute *AA */ nullptr, UsedAssumedInformation);
+    Initializer = *AssumedGV;
+    if (!Initializer)
+      return nullptr;
+  } else {
+    if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer()))
+      return nullptr;
+    if (!GV->hasInitializer())
+      return UndefValue::get(&Ty);
+
+    if (!Initializer)
+      Initializer = GV->getInitializer();
+  }
 
   if (RangePtr && !RangePtr->offsetOrSizeAreUnknown()) {
     APInt Offset = APInt(64, RangePtr->Offset);
-    return ConstantFoldLoadFromConst(GV->getInitializer(), &Ty, Offset, DL);
+    return ConstantFoldLoadFromConst(Initializer, &Ty, Offset, DL);
   }
 
-  return ConstantFoldLoadFromUniformValue(GV->getInitializer(), &Ty);
+  return ConstantFoldLoadFromUniformValue(Initializer, &Ty);
 }
 
 bool AA::isValidInScope(const Value &V, const Function *Scope) {
@@ -481,7 +495,7 @@ static bool getPotentialCopiesOfMemoryValue(
     if (IsLoad && !HasBeenWrittenTo && !Range.isUnassigned()) {
       const DataLayout &DL = A.getDataLayout();
       Value *InitialValue =
-          AA::getInitialValueForObj(Obj, *I.getType(), TLI, DL, &Range);
+          AA::getInitialValueForObj(A, Obj, *I.getType(), TLI, DL, &Range);
       if (!InitialValue) {
         LLVM_DEBUG(dbgs() << "Could not determine required initial value of "
                              "underlying object, abort!\n");