From a8b55b6939a5962d5b2bf1a57980562d6f3045e5 Mon Sep 17 00:00:00 2001 From: Utkarsh Saxena Date: Tue, 22 Sep 2020 07:56:08 +0200 Subject: [PATCH] [clangd] Use Decision Forest to score code completions. By default clangd will score a code completion item using heuristics model. Scoring can be done by Decision Forest model by passing `--ranking_model=decision_forest` to clangd. Features omitted from the model: - `NameMatch` is excluded because the final score must be multiplicative in `NameMatch` to allow rescoring by the editor. - `NeedsFixIts` is excluded because the generating dataset that needs 'fixits' is non-trivial. There are multiple ways (heuristics) to combine the above two features with the prediction of the DF: - `NeedsFixIts` is used as is with a penalty of `0.5`. Various alternatives of combining NameMatch `N` and Decision forest Prediction `P` - N * scale(P, 0, 1): Linearly scale the output of model to range [0, 1] - N * a^P: - More natural: Prediction of each Decision Tree can be considered as a multiplicative boost (like NameMatch) - Ordering is independent of the absolute value of P. Order of two items is proportional to `a^{difference in model prediction score}`. Higher `a` gives higher weightage to model output as compared to NameMatch score. Baseline MRR = 0.619 MRR for various combinations: N * P = 0.6346, advantage%=2.5768 N * 1.1^P = 0.6600, advantage%=6.6853 N * **1.2**^P = 0.6669, advantage%=**7.8005** N * **1.3**^P = 0.6668, advantage%=**7.7795** N * **1.4**^P = 0.6659, advantage%=**7.6270** N * 1.5^P = 0.6646, advantage%=7.4200 N * 1.6^P = 0.6636, advantage%=7.2671 N * 1.7^P = 0.6629, advantage%=7.1450 N * 2^P = 0.6612, advantage%=6.8673 N * 2.5^P = 0.6598, advantage%=6.6491 N * 3^P = 0.6590, advantage%=6.5242 N * scaled[0, 1] = 0.6465, advantage%=4.5054 Differential Revision: https://reviews.llvm.org/D88281 --- clang-tools-extra/clangd/CodeComplete.cpp | 48 ++++++++++++++++++---- clang-tools-extra/clangd/CodeComplete.h | 16 ++++++++ clang-tools-extra/clangd/Quality.cpp | 29 +++++++++++++ clang-tools-extra/clangd/Quality.h | 7 ++++ clang-tools-extra/clangd/tool/ClangdMain.cpp | 22 ++++++++++ .../clangd/unittests/CodeCompleteTests.cpp | 41 ++++++++++++++---- 6 files changed, 145 insertions(+), 18 deletions(-) diff --git a/clang-tools-extra/clangd/CodeComplete.cpp b/clang-tools-extra/clangd/CodeComplete.cpp index 4d5b297..90e793f 100644 --- a/clang-tools-extra/clangd/CodeComplete.cpp +++ b/clang-tools-extra/clangd/CodeComplete.cpp @@ -1625,6 +1625,43 @@ private: return Filter->match(C.Name); } + CodeCompletion::Scores + evaluateCompletion(const SymbolQualitySignals &Quality, + const SymbolRelevanceSignals &Relevance) { + using RM = CodeCompleteOptions::CodeCompletionRankingModel; + CodeCompletion::Scores Scores; + switch (Opts.RankingModel) { + case RM::Heuristics: + Scores.Quality = Quality.evaluate(); + Scores.Relevance = Relevance.evaluate(); + Scores.Total = + evaluateSymbolAndRelevance(Scores.Quality, Scores.Relevance); + // NameMatch is in fact a multiplier on total score, so rescoring is + // sound. + Scores.ExcludingName = Relevance.NameMatch + ? Scores.Total / Relevance.NameMatch + : Scores.Quality; + return Scores; + + case RM::DecisionForest: + Scores.Quality = 0; + Scores.Relevance = 0; + // Exponentiating DecisionForest prediction makes the score of each tree a + // multiplciative boost (like NameMatch). This allows us to weigh the + // prediciton score and NameMatch appropriately. + Scores.ExcludingName = pow(Opts.DecisionForestBase, + evaluateDecisionForest(Quality, Relevance)); + // NeedsFixIts is not part of the DecisionForest as generating training + // data that needs fixits is not-feasible. + if (Relevance.NeedsFixIts) + Scores.ExcludingName *= 0.5; + // NameMatch should be a multiplier on total score to support rescoring. + Scores.Total = Relevance.NameMatch * Scores.ExcludingName; + return Scores; + } + llvm_unreachable("Unhandled CodeCompletion ranking model."); + } + // Scores a candidate and adds it to the TopN structure. void addCandidate(TopN &Candidates, CompletionCandidate::Bundle Bundle) { @@ -1632,6 +1669,7 @@ private: SymbolRelevanceSignals Relevance; Relevance.Context = CCContextKind; Relevance.Name = Bundle.front().Name; + Relevance.FilterLength = HeuristicPrefix.Name.size(); Relevance.Query = SymbolRelevanceSignals::CodeComplete; Relevance.FileProximityMatch = FileProximity.getPointer(); if (ScopeProximity) @@ -1680,15 +1718,7 @@ private: } } - CodeCompletion::Scores Scores; - Scores.Quality = Quality.evaluate(); - Scores.Relevance = Relevance.evaluate(); - Scores.Total = evaluateSymbolAndRelevance(Scores.Quality, Scores.Relevance); - // NameMatch is in fact a multiplier on total score, so rescoring is sound. - Scores.ExcludingName = Relevance.NameMatch - ? Scores.Total / Relevance.NameMatch - : Scores.Quality; - + CodeCompletion::Scores Scores = evaluateCompletion(Quality, Relevance); if (Opts.RecordCCResult) Opts.RecordCCResult(toCodeCompletion(Bundle), Quality, Relevance, Scores.Total); diff --git a/clang-tools-extra/clangd/CodeComplete.h b/clang-tools-extra/clangd/CodeComplete.h index beffabd..82a2656 100644 --- a/clang-tools-extra/clangd/CodeComplete.h +++ b/clang-tools-extra/clangd/CodeComplete.h @@ -147,6 +147,22 @@ struct CodeCompleteOptions { std::function RecordCCResult; + + /// Model to use for ranking code completion candidates. + enum CodeCompletionRankingModel { + Heuristics, + DecisionForest, + } RankingModel = Heuristics; + + /// Weight for combining NameMatch and Prediction of DecisionForest. + /// CompletionScore is NameMatch * pow(Base, Prediction). + /// The optimal value of Base largely depends on the semantics of the model + /// and prediction score (e.g. algorithm used during training, number of + /// trees, etc.). Usually if the range of Prediciton is [-20, 20] then a Base + /// in [1.2, 1.7] works fine. + /// Semantics: E.g. the completion score reduces by 50% if the Prediciton + /// score is reduced by 2.6 points for Base = 1.3. + float DecisionForestBase = 1.3f; }; // Semi-structured representation of a code-complete suggestion for our C++ API. diff --git a/clang-tools-extra/clangd/Quality.cpp b/clang-tools-extra/clangd/Quality.cpp index bf0c095..37f1cf6 100644 --- a/clang-tools-extra/clangd/Quality.cpp +++ b/clang-tools-extra/clangd/Quality.cpp @@ -8,6 +8,7 @@ #include "Quality.h" #include "AST.h" +#include "CompletionModel.h" #include "FileDistance.h" #include "SourceCode.h" #include "URI.h" @@ -486,6 +487,34 @@ float evaluateSymbolAndRelevance(float SymbolQuality, float SymbolRelevance) { return SymbolQuality * SymbolRelevance; } +float evaluateDecisionForest(const SymbolQualitySignals &Quality, + const SymbolRelevanceSignals &Relevance) { + Example E; + E.setIsDeprecated(Quality.Deprecated); + E.setIsReservedName(Quality.ReservedName); + E.setIsImplementationDetail(Quality.ImplementationDetail); + E.setNumReferences(Quality.References); + E.setSymbolCategory(Quality.Category); + + SymbolRelevanceSignals::DerivedSignals Derived = + Relevance.calculateDerivedSignals(); + E.setIsNameInContext(Derived.NameMatchesContext); + E.setIsForbidden(Relevance.Forbidden); + E.setIsInBaseClass(Relevance.InBaseClass); + E.setFileProximityDistance(Derived.FileProximityDistance); + E.setSemaFileProximityScore(Relevance.SemaFileProximityScore); + E.setSymbolScopeDistance(Derived.ScopeProximityDistance); + E.setSemaSaysInScope(Relevance.SemaSaysInScope); + E.setScope(Relevance.Scope); + E.setContextKind(Relevance.Context); + E.setIsInstanceMember(Relevance.IsInstanceMember); + E.setHadContextType(Relevance.HadContextType); + E.setHadSymbolType(Relevance.HadSymbolType); + E.setTypeMatchesPreferred(Relevance.TypeMatchesPreferred); + E.setFilterLength(Relevance.FilterLength); + return Evaluate(E); +} + // Produces an integer that sorts in the same order as F. // That is: a < b <==> encodeFloat(a) < encodeFloat(b). static uint32_t encodeFloat(float F) { diff --git a/clang-tools-extra/clangd/Quality.h b/clang-tools-extra/clangd/Quality.h index 04c6ce2..694653e 100644 --- a/clang-tools-extra/clangd/Quality.h +++ b/clang-tools-extra/clangd/Quality.h @@ -77,6 +77,7 @@ struct SymbolQualitySignals { void merge(const CodeCompletionResult &SemaCCResult); void merge(const Symbol &IndexResult); + // FIXME(usx): Rename to evaluateHeuristics(). // Condense these signals down to a single number, higher is better. float evaluate() const; }; @@ -136,6 +137,10 @@ struct SymbolRelevanceSignals { // Whether the item matches the type expected in the completion context. bool TypeMatchesPreferred = false; + /// Length of the unqualified partial name of Symbol typed in + /// CompletionPrefix. + unsigned FilterLength = 0; + /// Set of derived signals computed by calculateDerivedSignals(). Must not be /// set explicitly. struct DerivedSignals { @@ -161,6 +166,8 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &, /// Combine symbol quality and relevance into a single score. float evaluateSymbolAndRelevance(float SymbolQuality, float SymbolRelevance); +float evaluateDecisionForest(const SymbolQualitySignals &Quality, + const SymbolRelevanceSignals &Relevance); /// TopN is a lossy container that preserves only the "best" N elements. template > class TopN { public: diff --git a/clang-tools-extra/clangd/tool/ClangdMain.cpp b/clang-tools-extra/clangd/tool/ClangdMain.cpp index 9660f1b..8e5d6cb 100644 --- a/clang-tools-extra/clangd/tool/ClangdMain.cpp +++ b/clang-tools-extra/clangd/tool/ClangdMain.cpp @@ -167,6 +167,26 @@ opt CodeCompletionParse{ Hidden, }; +opt RankingModel{ + "ranking-model", + cat(Features), + desc("Model to use to rank code-completion items"), + values(clEnumValN(CodeCompleteOptions::Heuristics, "heuristics", + "Use hueristics to rank code completion items"), + clEnumValN(CodeCompleteOptions::DecisionForest, "decision_forest", + "Use Decision Forest model to rank completion items")), + init(CodeCompleteOptions().RankingModel), + Hidden, +}; + +opt DecisionForestBase{ + "decision-forest-base", + cat(Features), + desc("Base for exponentiating the prediction from DecisionForest."), + init(CodeCompleteOptions().DecisionForestBase), + Hidden, +}; + // FIXME: also support "plain" style where signatures are always omitted. enum CompletionStyleFlag { Detailed, Bundled }; opt CompletionStyle{ @@ -739,6 +759,8 @@ clangd accepts flags on the commandline, and in the CLANGD_FLAGS environment var CCOpts.EnableFunctionArgSnippets = EnableFunctionArgSnippets; CCOpts.AllScopes = AllScopesCompletion; CCOpts.RunParser = CodeCompletionParse; + CCOpts.RankingModel = RankingModel; + CCOpts.DecisionForestBase = DecisionForestBase; RealThreadsafeFS TFS; std::vector> ProviderStack; diff --git a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp index 460976d..de73bc6 100644 --- a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp +++ b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp @@ -10,7 +10,6 @@ #include "ClangdServer.h" #include "CodeComplete.h" #include "Compiler.h" -#include "CompletionModel.h" #include "Matchers.h" #include "Protocol.h" #include "Quality.h" @@ -163,14 +162,38 @@ Symbol withReferences(int N, Symbol S) { return S; } -TEST(DecisionForestRuntime, SanityTest) { - using Example = clangd::Example; - using clangd::Evaluate; - Example E1; - E1.setContextKind(ContextKind::CCC_ArrowMemberAccess); - Example E2; - E2.setContextKind(ContextKind::CCC_SymbolOrNewName); - EXPECT_GT(Evaluate(E1), Evaluate(E2)); +TEST(DecisionForestRankingModel, NameMatchSanityTest) { + clangd::CodeCompleteOptions Opts; + Opts.RankingModel = CodeCompleteOptions::DecisionForest; + auto Results = completions( + R"cpp( +struct MemberAccess { + int ABG(); + int AlphaBetaGamma(); +}; +int func() { MemberAccess().ABG^ } +)cpp", + /*IndexSymbols=*/{}, Opts); + EXPECT_THAT(Results.Completions, + ElementsAre(Named("ABG"), Named("AlphaBetaGamma"))); +} + +TEST(DecisionForestRankingModel, ReferencesAffectRanking) { + clangd::CodeCompleteOptions Opts; + Opts.RankingModel = CodeCompleteOptions::DecisionForest; + constexpr int NumReferences = 100000; + EXPECT_THAT( + completions("int main() { clang^ }", + {ns("clangA"), withReferences(NumReferences, func("clangD"))}, + Opts) + .Completions, + ElementsAre(Named("clangD"), Named("clangA"))); + EXPECT_THAT( + completions("int main() { clang^ }", + {withReferences(NumReferences, ns("clangA")), func("clangD")}, + Opts) + .Completions, + ElementsAre(Named("clangA"), Named("clangD"))); } TEST(CompletionTest, Limit) { -- 2.7.4