Go-to-type on smart_ptr<Foo> now also shows Foo
authorTom Praschan <13141438+tom-anders@users.noreply.github.com>
Mon, 11 Jul 2022 10:13:35 +0000 (12:13 +0200)
committerSam McCall <sam.mccall@gmail.com>
Mon, 11 Jul 2022 10:13:47 +0000 (12:13 +0200)
Fixes clangd/clangd#1026

Reviewed By: sammccall

Differential Revision: https://reviews.llvm.org/D128826

clang-tools-extra/clangd/HeuristicResolver.h
clang-tools-extra/clangd/XRefs.cpp
clang-tools-extra/clangd/unittests/XRefsTests.cpp

index 26a5bc3..2b66e4b 100644 (file)
@@ -69,6 +69,11 @@ public:
   const Type *
   resolveNestedNameSpecifierToType(const NestedNameSpecifier *NNS) const;
 
+  // Given the type T of a dependent expression that appears of the LHS of a
+  // "->", heuristically find a corresponding pointee type in whose scope we
+  // could look up the name appearing on the RHS.
+  const Type *getPointeeType(const Type *T) const;
+
 private:
   ASTContext &Ctx;
 
@@ -89,11 +94,6 @@ private:
   // `E`.
   const Type *resolveExprToType(const Expr *E) const;
   std::vector<const NamedDecl *> resolveExprToDecls(const Expr *E) const;
-
-  // Given the type T of a dependent expression that appears of the LHS of a
-  // "->", heuristically find a corresponding pointee type in whose scope we
-  // could look up the name appearing on the RHS.
-  const Type *getPointeeType(const Type *T) const;
 };
 
 } // namespace clangd
index fa8320e..c620b38 100644 (file)
@@ -9,6 +9,7 @@
 #include "AST.h"
 #include "FindSymbols.h"
 #include "FindTarget.h"
+#include "HeuristicResolver.h"
 #include "ParsedAST.h"
 #include "Protocol.h"
 #include "Quality.h"
@@ -1907,38 +1908,54 @@ static QualType typeForNode(const SelectionTree::Node *N) {
   return QualType();
 }
 
-// Given a type targeted by the cursor, return a type that's more interesting
+// Given a type targeted by the cursor, return one or more types that are more interesting
 // to target.
-static QualType unwrapFindType(QualType T) {
+static void unwrapFindType(
+    QualType T, const HeuristicResolver* H, llvm::SmallVector<QualType>& Out) {
   if (T.isNull())
-    return T;
+    return;
 
   // If there's a specific type alias, point at that rather than unwrapping.
   if (const auto* TDT = T->getAs<TypedefType>())
-    return QualType(TDT, 0);
+    return Out.push_back(QualType(TDT, 0));
 
   // Pointers etc => pointee type.
   if (const auto *PT = T->getAs<PointerType>())
-    return unwrapFindType(PT->getPointeeType());
+    return unwrapFindType(PT->getPointeeType(), H, Out);
   if (const auto *RT = T->getAs<ReferenceType>())
-    return unwrapFindType(RT->getPointeeType());
+    return unwrapFindType(RT->getPointeeType(), H, Out);
   if (const auto *AT = T->getAsArrayTypeUnsafe())
-    return unwrapFindType(AT->getElementType());
-  // FIXME: use HeuristicResolver to unwrap smart pointers?
+    return unwrapFindType(AT->getElementType(), H, Out);
 
   // Function type => return type.
   if (auto *FT = T->getAs<FunctionType>())
-    return unwrapFindType(FT->getReturnType());
+    return unwrapFindType(FT->getReturnType(), H, Out);
   if (auto *CRD = T->getAsCXXRecordDecl()) {
     if (CRD->isLambda())
-      return unwrapFindType(CRD->getLambdaCallOperator()->getReturnType());
+      return unwrapFindType(CRD->getLambdaCallOperator()->getReturnType(), H, Out);
     // FIXME: more cases we'd prefer the return type of the call operator?
     //        std::function etc?
   }
 
-  return T;
+  // For smart pointer types, add the underlying type
+  if (H)
+    if (const auto* PointeeType = H->getPointeeType(T.getNonReferenceType().getTypePtr())) {
+        unwrapFindType(QualType(PointeeType, 0), H, Out);
+        return Out.push_back(T);
+    }
+
+  return Out.push_back(T);
 }
 
+// Convenience overload, to allow calling this without the out-parameter
+static llvm::SmallVector<QualType> unwrapFindType(
+    QualType T, const HeuristicResolver* H) {
+    llvm::SmallVector<QualType> Result;
+    unwrapFindType(T, H, Result);
+    return Result;
+}
+
+
 std::vector<LocatedSymbol> findType(ParsedAST &AST, Position Pos) {
   const SourceManager &SM = AST.getSourceManager();
   auto Offset = positionToOffset(SM.getBufferData(SM.getMainFileID()), Pos);
@@ -1951,10 +1968,16 @@ std::vector<LocatedSymbol> findType(ParsedAST &AST, Position Pos) {
   // The general scheme is: position -> AST node -> type -> declaration.
   auto SymbolsFromNode =
       [&AST](const SelectionTree::Node *N) -> std::vector<LocatedSymbol> {
-    QualType Type = unwrapFindType(typeForNode(N));
-    if (Type.isNull())
-      return {};
-    return locateSymbolForType(AST, Type);
+    std::vector<LocatedSymbol> LocatedSymbols;
+
+    // NOTE: unwrapFindType might return duplicates for something like
+    // unique_ptr<unique_ptr<T>>. Let's *not* remove them, because it gives you some
+    // information about the type you may have not known before
+    // (since unique_ptr<unique_ptr<T>> != unique_ptr<T>).
+    for (const QualType& Type : unwrapFindType(typeForNode(N), AST.getHeuristicResolver()))
+        llvm::copy(locateSymbolForType(AST, Type), std::back_inserter(LocatedSymbols));
+
+    return LocatedSymbols;
   };
   SelectionTree::createEach(AST.getASTContext(), AST.getTokens(), *Offset,
                             *Offset, [&](SelectionTree ST) {
index 9721f30..9294aee 100644 (file)
@@ -1786,11 +1786,11 @@ TEST(FindImplementations, CaptureDefintion) {
 
 TEST(FindType, All) {
   Annotations HeaderA(R"cpp(
-    struct [[Target]] { operator int() const; };
+    struct $Target[[Target]] { operator int() const; };
     struct Aggregate { Target a, b; };
     Target t;
 
-    template <typename T> class smart_ptr {
+    template <typename T> class $smart_ptr[[smart_ptr]] {
       T& operator*();
       T* operator->();
       T* get();
@@ -1829,11 +1829,11 @@ TEST(FindType, All) {
     ASSERT_GT(A.points().size(), 0u) << Case;
     for (auto Pos : A.points())
       EXPECT_THAT(findType(AST, Pos),
-                  ElementsAre(sym("Target", HeaderA.range(), HeaderA.range())))
+                  ElementsAre(
+                    sym("Target", HeaderA.range("Target"), HeaderA.range("Target"))))
           << Case;
   }
 
-  // FIXME: We'd like these cases to work. Fix them and move above.
   for (const llvm::StringRef Case : {
            "smart_ptr<Target> ^tsmart;",
        }) {
@@ -1842,7 +1842,10 @@ TEST(FindType, All) {
     ParsedAST AST = TU.build();
 
     EXPECT_THAT(findType(AST, A.point()),
-                Not(Contains(sym("Target", HeaderA.range(), HeaderA.range()))))
+                UnorderedElementsAre(
+                  sym("Target", HeaderA.range("Target"), HeaderA.range("Target")),
+                  sym("smart_ptr", HeaderA.range("smart_ptr"), HeaderA.range("smart_ptr"))
+                ))
         << Case;
   }
 }