[AST] Use IntrusiveRefCntPtr for Introspection LocationCall.
authorNathan James <n.james93@hotmail.co.uk>
Wed, 14 Apr 2021 23:12:21 +0000 (00:12 +0100)
committerNathan James <n.james93@hotmail.co.uk>
Wed, 14 Apr 2021 23:12:22 +0000 (00:12 +0100)
Reviewed By: steveire

Differential Revision: https://reviews.llvm.org/D100378

clang/include/clang/Tooling/NodeIntrospection.h
clang/lib/Tooling/DumpTool/generate_cxx_src_locs.py
clang/lib/Tooling/NodeIntrospection.cpp

index 28007c4..70bfeba 100644 (file)
@@ -15,8 +15,7 @@
 
 #include "clang/AST/ASTTypeTraits.h"
 #include "clang/AST/DeclarationName.h"
-
-#include <memory>
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
 #include <set>
 
 namespace clang {
@@ -30,16 +29,19 @@ class CXXBaseSpecifier;
 
 namespace tooling {
 
-class LocationCall {
+class LocationCall;
+using SharedLocationCall = llvm::IntrusiveRefCntPtr<LocationCall>;
+
+class LocationCall : public llvm::ThreadSafeRefCountedBase<LocationCall> {
 public:
   enum LocationCallFlags { NoFlags, ReturnsPointer, IsCast };
-  LocationCall(std::shared_ptr<LocationCall> on, std::string name,
+  LocationCall(SharedLocationCall on, std::string name,
                LocationCallFlags flags = NoFlags)
-      : m_on(on), m_name(name), m_flags(flags) {}
-  LocationCall(std::shared_ptr<LocationCall> on, std::string name,
+      : m_flags(flags), m_on(std::move(on)), m_name(name) {}
+  LocationCall(SharedLocationCall on, std::string name,
                std::vector<std::string> const &args,
                LocationCallFlags flags = NoFlags)
-      : m_on(on), m_name(name), m_flags(flags) {}
+      : m_flags(flags), m_on(std::move(on)), m_name(name) {}
 
   LocationCall *on() const { return m_on.get(); }
   StringRef name() const { return m_name; }
@@ -48,10 +50,10 @@ public:
   bool isCast() const { return m_flags & IsCast; }
 
 private:
-  std::shared_ptr<LocationCall> m_on;
+  LocationCallFlags m_flags;
+  SharedLocationCall m_on;
   std::string m_name;
   std::vector<std::string> m_args;
-  LocationCallFlags m_flags;
 };
 
 class LocationCallFormatterCpp {
@@ -61,20 +63,20 @@ public:
 
 namespace internal {
 struct RangeLessThan {
-  bool operator()(
-      std::pair<SourceRange, std::shared_ptr<LocationCall>> const &LHS,
-      std::pair<SourceRange, std::shared_ptr<LocationCall>> const &RHS) const;
+  bool operator()(std::pair<SourceRange, SharedLocationCall> const &LHS,
+                  std::pair<SourceRange, SharedLocationCall> const &RHS) const;
+  bool
+  operator()(std::pair<SourceLocation, SharedLocationCall> const &LHS,
+             std::pair<SourceLocation, SharedLocationCall> const &RHS) const;
 };
+
 } // namespace internal
 
-template <typename T, typename U, typename Comp = std::less<std::pair<T, U>>>
-using UniqueMultiMap = std::set<std::pair<T, U>, Comp>;
+template <typename T, typename U>
+using UniqueMultiMap = std::set<std::pair<T, U>, internal::RangeLessThan>;
 
-using SourceLocationMap =
-    UniqueMultiMap<SourceLocation, std::shared_ptr<LocationCall>>;
-using SourceRangeMap =
-    UniqueMultiMap<SourceRange, std::shared_ptr<LocationCall>,
-                   internal::RangeLessThan>;
+using SourceLocationMap = UniqueMultiMap<SourceLocation, SharedLocationCall>;
+using SourceRangeMap = UniqueMultiMap<SourceRange, SharedLocationCall>;
 
 struct NodeLocationAccessors {
   SourceLocationMap LocationAccessors;
index e89ed1c..0adebeb 100755 (executable)
@@ -33,7 +33,7 @@ using RangeAndString = SourceRangeMap::value_type;
     def GenerateBaseGetLocationsDeclaration(self, CladeName):
         self.implementationContent += \
             """
-void GetLocationsImpl(std::shared_ptr<LocationCall> const& Prefix,
+void GetLocationsImpl(SharedLocationCall const& Prefix,
     clang::{0} const *Object, SourceLocationMap &Locs,
     SourceRangeMap &Rngs);
 """.format(CladeName)
@@ -42,7 +42,7 @@ void GetLocationsImpl(std::shared_ptr<LocationCall> const& Prefix,
 
         self.implementationContent += \
             """
-static void GetLocations{0}(std::shared_ptr<LocationCall> const& Prefix,
+static void GetLocations{0}(SharedLocationCall const& Prefix,
     clang::{0} const &Object,
     SourceLocationMap &Locs, SourceRangeMap &Rngs)
 {{
@@ -53,7 +53,7 @@ static void GetLocations{0}(std::shared_ptr<LocationCall> const& Prefix,
                 self.implementationContent += \
                     """
   Locs.insert(LocationAndString(Object.{0}(),
-    std::make_shared<LocationCall>(Prefix, "{0}")));
+    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}")));
 """.format(locName)
 
             self.implementationContent += '\n'
@@ -63,7 +63,7 @@ static void GetLocations{0}(std::shared_ptr<LocationCall> const& Prefix,
                 self.implementationContent += \
                     """
   Rngs.insert(RangeAndString(Object.{0}(),
-    std::make_shared<LocationCall>(Prefix, "{0}")));
+    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}")));
 """.format(rngName)
 
             self.implementationContent += '\n'
@@ -83,7 +83,7 @@ static void GetLocations{0}(std::shared_ptr<LocationCall> const& Prefix,
             'GetLocations(clang::{0} const *Object)'.format(CladeName)
         ImplSignature = \
             """
-GetLocationsImpl(std::shared_ptr<LocationCall> const& Prefix,
+GetLocationsImpl(SharedLocationCall const& Prefix,
     clang::{0} const *Object, SourceLocationMap &Locs,
     SourceRangeMap &Rngs)
 """.format(CladeName)
@@ -108,7 +108,7 @@ if (auto Derived = llvm::dyn_cast<clang::{0}>(Object)) {{
             """
 {0} NodeIntrospection::{1} {{
   NodeLocationAccessors Result;
-  std::shared_ptr<LocationCall> Prefix;
+  SharedLocationCall Prefix;
 
   GetLocationsImpl(Prefix, Object, Result.LocationAccessors,
                    Result.RangeAccessors);
index 89e8c19..bb0e6ec 100644 (file)
@@ -36,8 +36,8 @@ std::string LocationCallFormatterCpp::format(LocationCall *Call) {
 
 namespace internal {
 bool RangeLessThan::operator()(
-    std::pair<SourceRange, std::shared_ptr<LocationCall>> const &LHS,
-    std::pair<SourceRange, std::shared_ptr<LocationCall>> const &RHS) const {
+    std::pair<SourceRange, SharedLocationCall> const &LHS,
+    std::pair<SourceRange, SharedLocationCall> const &RHS) const {
   if (!LHS.first.isValid() || !RHS.first.isValid())
     return false;
 
@@ -53,6 +53,13 @@ bool RangeLessThan::operator()(
 
   return LHS.second->name() < RHS.second->name();
 }
+bool RangeLessThan::operator()(
+    std::pair<SourceLocation, SharedLocationCall> const &LHS,
+    std::pair<SourceLocation, SharedLocationCall> const &RHS) const {
+  if (LHS.first == RHS.first)
+    return LHS.second->name() < RHS.second->name();
+  return LHS.first < RHS.first;
+}
 } // namespace internal
 
 } // namespace tooling