[clangd] Boost code completion results that are narrowly scoped (local, members)
authorSam McCall <sam.mccall@gmail.com>
Tue, 5 Jun 2018 16:30:25 +0000 (16:30 +0000)
committerSam McCall <sam.mccall@gmail.com>
Tue, 5 Jun 2018 16:30:25 +0000 (16:30 +0000)
Summary:
This signal is considered a relevance rather than a quality signal because it's
dependent on the query (the fact that it's completion, and implicitly the query
context).

This is part of the effort to reduce reliance on Sema priority, so we can have
consistent ranking between Index and Sema results.

Reviewers: ioeric

Subscribers: klimek, ilya-biryukov, MaskRay, jkorous, cfe-commits

Differential Revision: https://reviews.llvm.org/D47762

llvm-svn: 334026

clang-tools-extra/clangd/ClangdUnit.cpp
clang-tools-extra/clangd/ClangdUnit.h
clang-tools-extra/clangd/CodeComplete.cpp
clang-tools-extra/clangd/Quality.cpp
clang-tools-extra/clangd/Quality.h
clang-tools-extra/unittests/clangd/QualityTests.cpp
clang-tools-extra/unittests/clangd/TestTU.cpp
clang-tools-extra/unittests/clangd/TestTU.h

index bf5de6a..cd4104c 100644 (file)
@@ -51,11 +51,11 @@ template <class T> std::size_t getUsedBytes(const std::vector<T> &Vec) {
 
 class DeclTrackingASTConsumer : public ASTConsumer {
 public:
-  DeclTrackingASTConsumer(std::vector<const Decl *> &TopLevelDecls)
+  DeclTrackingASTConsumer(std::vector<Decl *> &TopLevelDecls)
       : TopLevelDecls(TopLevelDecls) {}
 
   bool HandleTopLevelDecl(DeclGroupRef DG) override {
-    for (const Decl *D : DG) {
+    for (Decl *D : DG) {
       // ObjCMethodDecl are not actually top-level decls.
       if (isa<ObjCMethodDecl>(D))
         continue;
@@ -66,14 +66,12 @@ public:
   }
 
 private:
-  std::vector<const Decl *> &TopLevelDecls;
+  std::vector<Decl *> &TopLevelDecls;
 };
 
 class ClangdFrontendAction : public SyntaxOnlyAction {
 public:
-  std::vector<const Decl *> takeTopLevelDecls() {
-    return std::move(TopLevelDecls);
-  }
+  std::vector<Decl *> takeTopLevelDecls() { return std::move(TopLevelDecls); }
 
 protected:
   std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
@@ -82,7 +80,7 @@ protected:
   }
 
 private:
-  std::vector<const Decl *> TopLevelDecls;
+  std::vector<Decl *> TopLevelDecls;
 };
 
 class CppFilePreambleCallbacks : public PreambleCallbacks {
@@ -174,7 +172,7 @@ ParsedAST::Build(std::unique_ptr<clang::CompilerInvocation> CI,
   // CompilerInstance won't run this callback, do it directly.
   ASTDiags.EndSourceFile();
 
-  std::vector<const Decl *> ParsedDecls = Action->takeTopLevelDecls();
+  std::vector<Decl *> ParsedDecls = Action->takeTopLevelDecls();
   std::vector<Diag> Diags = ASTDiags.take();
   // Add diagnostics from the preamble, if any.
   if (Preamble)
@@ -210,7 +208,7 @@ const Preprocessor &ParsedAST::getPreprocessor() const {
   return Clang->getPreprocessor();
 }
 
-ArrayRef<const Decl *> ParsedAST::getLocalTopLevelDecls() {
+ArrayRef<Decl *> ParsedAST::getLocalTopLevelDecls() {
   return LocalTopLevelDecls;
 }
 
@@ -261,7 +259,7 @@ PreambleData::PreambleData(PrecompiledPreamble Preamble,
 ParsedAST::ParsedAST(std::shared_ptr<const PreambleData> Preamble,
                      std::unique_ptr<CompilerInstance> Clang,
                      std::unique_ptr<FrontendAction> Action,
-                     std::vector<const Decl *> LocalTopLevelDecls,
+                     std::vector<Decl *> LocalTopLevelDecls,
                      std::vector<Diag> Diags, std::vector<Inclusion> Inclusions)
     : Preamble(std::move(Preamble)), Clang(std::move(Clang)),
       Action(std::move(Action)), Diags(std::move(Diags)),
index 8d9c3f0..c678d56 100644 (file)
@@ -91,7 +91,8 @@ public:
 
   /// This function returns top-level decls present in the main file of the AST.
   /// The result does not include the decls that come from the preamble.
-  ArrayRef<const Decl *> getLocalTopLevelDecls();
+  /// (These should be const, but RecursiveASTVisitor requires Decl*).
+  ArrayRef<Decl *> getLocalTopLevelDecls();
 
   const std::vector<Diag> &getDiagnostics() const;
 
@@ -104,8 +105,8 @@ private:
   ParsedAST(std::shared_ptr<const PreambleData> Preamble,
             std::unique_ptr<CompilerInstance> Clang,
             std::unique_ptr<FrontendAction> Action,
-            std::vector<const Decl *> LocalTopLevelDecls,
-            std::vector<Diag> Diags, std::vector<Inclusion> Inclusions);
+            std::vector<Decl *> LocalTopLevelDecls, std::vector<Diag> Diags,
+            std::vector<Inclusion> Inclusions);
 
   // In-memory preambles must outlive the AST, it is important that this member
   // goes before Clang and Action.
@@ -122,7 +123,7 @@ private:
   std::vector<Diag> Diags;
   // Top-level decls inside the current file. Not that this does not include
   // top-level decls from the preamble.
-  std::vector<const Decl *> LocalTopLevelDecls;
+  std::vector<Decl *> LocalTopLevelDecls;
   std::vector<Inclusion> Inclusions;
 };
 
index 7de69e3..6fb1d9a 100644 (file)
@@ -1007,12 +1007,15 @@ private:
 
     SymbolQualitySignals Quality;
     SymbolRelevanceSignals Relevance;
+    Relevance.Query = SymbolRelevanceSignals::CodeComplete;
     if (auto FuzzyScore = Filter->match(C.Name))
       Relevance.NameMatch = *FuzzyScore;
     else
       return;
-    if (IndexResult)
+    if (IndexResult) {
       Quality.merge(*IndexResult);
+      Relevance.merge(*IndexResult);
+    }
     if (SemaResult) {
       Quality.merge(*SemaResult);
       Relevance.merge(*SemaResult);
index 831bb98..9ac9955 100644 (file)
@@ -67,6 +67,28 @@ raw_ostream &operator<<(raw_ostream &OS, const SymbolQualitySignals &S) {
   return OS;
 }
 
+static SymbolRelevanceSignals::AccessibleScope
+ComputeScope(const NamedDecl &D) {
+  bool InClass;
+  for (const DeclContext *DC = D.getDeclContext(); !DC->isFileContext();
+       DC = DC->getParent()) {
+    if (DC->isFunctionOrMethod())
+      return SymbolRelevanceSignals::FunctionScope;
+    InClass = InClass || DC->isRecord();
+  }
+  if (InClass)
+    return SymbolRelevanceSignals::ClassScope;
+  // This threshold could be tweaked, e.g. to treat module-visible as global.
+  if (D.getLinkageInternal() < ExternalLinkage)
+    return SymbolRelevanceSignals::FileScope;
+  return SymbolRelevanceSignals::GlobalScope;
+}
+
+void SymbolRelevanceSignals::merge(const Symbol &IndexResult) {
+  // FIXME: Index results always assumed to be at global scope. If Scope becomes
+  // relevant to non-completion requests, we should recognize class members etc.
+}
+
 void SymbolRelevanceSignals::merge(const CodeCompletionResult &SemaCCResult) {
   if (SemaCCResult.Availability == CXAvailability_NotAvailable ||
       SemaCCResult.Availability == CXAvailability_NotAccessible)
@@ -79,16 +101,41 @@ void SymbolRelevanceSignals::merge(const CodeCompletionResult &SemaCCResult) {
         hasDeclInMainFile(*SemaCCResult.Declaration) ? 1.0 : 0.0;
     ProximityScore = std::max(DeclProximity, ProximityScore);
   }
+
+  // Declarations are scoped, others (like macros) are assumed global.
+  if (SemaCCResult.Kind == CodeCompletionResult::RK_Declaration)
+    Scope = std::min(Scope, ComputeScope(*SemaCCResult.Declaration));
 }
 
 float SymbolRelevanceSignals::evaluate() const {
+  float Score = 1;
+
   if (Forbidden)
     return 0;
 
-  float Score = NameMatch;
+  Score *= NameMatch;
+
   // Proximity scores are [0,1] and we translate them into a multiplier in the
   // range from 1 to 2.
   Score *= 1 + ProximityScore;
+
+  // Symbols like local variables may only be referenced within their scope.
+  // Conversely if we're in that scope, it's likely we'll reference them.
+  if (Query == CodeComplete) {
+    // The narrower the scope where a symbol is visible, the more likely it is
+    // to be relevant when it is available.
+    switch (Scope) {
+    case GlobalScope:
+      break;
+    case FileScope:
+      Score *= 1.5;
+    case ClassScope:
+      Score *= 2;
+    case FunctionScope:
+      Score *= 4;
+    }
+  }
+
   return Score;
 }
 raw_ostream &operator<<(raw_ostream &OS, const SymbolRelevanceSignals &S) {
index b83a7eb..ae6f7d4 100644 (file)
@@ -67,7 +67,21 @@ struct SymbolRelevanceSignals {
   /// Proximity between best declaration and the query. [0-1], 1 is closest.
   float ProximityScore = 0;
 
+  // An approximate measure of where we expect the symbol to be used.
+  enum AccessibleScope {
+    FunctionScope,
+    ClassScope,
+    FileScope,
+    GlobalScope,
+  } Scope = GlobalScope;
+
+  enum QueryType {
+    CodeComplete,
+    Generic,
+  } Query = Generic;
+
   void merge(const CodeCompletionResult &SemaResult);
+  void merge(const Symbol &IndexResult);
 
   // Condense these signals down to a single number, higher is better.
   float evaluate() const;
index b8684d2..0d80570 100644 (file)
@@ -69,6 +69,8 @@ TEST(QualityTests, SymbolRelevanceSignalExtraction) {
 
     [[deprecated]]
     int deprecated() { return 0; }
+
+    namespace { struct X { void y() { int z; } }; }
   )cpp";
   auto AST = Test.build();
 
@@ -78,6 +80,7 @@ TEST(QualityTests, SymbolRelevanceSignalExtraction) {
                                        /*Accessible=*/false));
   EXPECT_EQ(Relevance.NameMatch, SymbolRelevanceSignals().NameMatch);
   EXPECT_TRUE(Relevance.Forbidden);
+  EXPECT_EQ(Relevance.Scope, SymbolRelevanceSignals::GlobalScope);
 
   Relevance = {};
   Relevance.merge(CodeCompletionResult(&findDecl(AST, "main"), 42));
@@ -88,6 +91,16 @@ TEST(QualityTests, SymbolRelevanceSignalExtraction) {
   Relevance = {};
   Relevance.merge(CodeCompletionResult(&findDecl(AST, "header_main"), 42));
   EXPECT_FLOAT_EQ(Relevance.ProximityScore, 1.0) << "Current file and header";
+
+  Relevance = {};
+  Relevance.merge(CodeCompletionResult(&findAnyDecl(AST, "X"), 42));
+  EXPECT_EQ(Relevance.Scope, SymbolRelevanceSignals::FileScope);
+  Relevance = {};
+  Relevance.merge(CodeCompletionResult(&findAnyDecl(AST, "y"), 42));
+  EXPECT_EQ(Relevance.Scope, SymbolRelevanceSignals::ClassScope);
+  Relevance = {};
+  Relevance.merge(CodeCompletionResult(&findAnyDecl(AST, "z"), 42));
+  EXPECT_EQ(Relevance.Scope, SymbolRelevanceSignals::FunctionScope);
 }
 
 // Do the signals move the scores in the direction we expect?
@@ -127,6 +140,12 @@ TEST(QualityTests, SymbolRelevanceSignalsSanity) {
   SymbolRelevanceSignals WithProximity;
   WithProximity.ProximityScore = 0.2f;
   EXPECT_GT(WithProximity.evaluate(), Default.evaluate());
+
+  SymbolRelevanceSignals Scoped;
+  Scoped.Scope = SymbolRelevanceSignals::FileScope;
+  EXPECT_EQ(Scoped.evaluate(), Default.evaluate());
+  Scoped.Query = SymbolRelevanceSignals::CodeComplete;
+  EXPECT_GT(Scoped.evaluate(), Default.evaluate());
 }
 
 TEST(QualityTests, SortText) {
index 259f36f..ed78098 100644 (file)
@@ -10,6 +10,7 @@
 #include "TestFS.h"
 #include "index/FileIndex.h"
 #include "index/MemIndex.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/Frontend/CompilerInvocation.h"
 #include "clang/Frontend/PCHContainerOperations.h"
 #include "clang/Frontend/Utils.h"
@@ -49,7 +50,6 @@ std::unique_ptr<SymbolIndex> TestTU::index() const {
   return MemIndex::build(headerSymbols());
 }
 
-// Look up a symbol by qualified name, which must be unique.
 const Symbol &findSymbol(const SymbolSlab &Slab, llvm::StringRef QName) {
   const Symbol *Result = nullptr;
   for (const Symbol &S : Slab) {
@@ -92,5 +92,26 @@ const NamedDecl &findDecl(ParsedAST &AST, llvm::StringRef QName) {
   return LookupDecl(*Scope, Components.back());
 }
 
+const NamedDecl &findAnyDecl(ParsedAST &AST, llvm::StringRef Name) {
+  struct Visitor : RecursiveASTVisitor<Visitor> {
+    llvm::StringRef Name;
+    llvm::SmallVector<const NamedDecl *, 1> Decls;
+    bool VisitNamedDecl(const NamedDecl *ND) {
+      if (auto *ID = ND->getIdentifier())
+        if (ID->getName() == Name)
+          Decls.push_back(ND);
+      return true;
+    }
+  } Visitor;
+  Visitor.Name = Name;
+  for (Decl *D : AST.getLocalTopLevelDecls())
+    Visitor.TraverseDecl(D);
+  if (Visitor.Decls.size() != 1) {
+    ADD_FAILURE() << Visitor.Decls.size() << " symbols named " << Name;
+    assert(Visitor.Decls.size() == 1);
+  }
+  return *Visitor.Decls.front();
+}
+
 } // namespace clangd
 } // namespace clang
index 5181d40..3284aca 100644 (file)
@@ -53,6 +53,8 @@ struct TestTU {
 const Symbol &findSymbol(const SymbolSlab &, llvm::StringRef QName);
 // Look up an AST symbol by qualified name, which must be unique and top-level.
 const NamedDecl &findDecl(ParsedAST &AST, llvm::StringRef QName);
+// Look up a main-file AST symbol by unqualified name, which must be unique.
+const NamedDecl &findAnyDecl(ParsedAST &AST, llvm::StringRef Name);
 
 } // namespace clangd
 } // namespace clang