[mlir-lsp-server] Add support for textDocument/documentSymbols
authorRiver Riddle <riddleriver@gmail.com>
Thu, 10 Jun 2021 17:57:55 +0000 (10:57 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 10 Jun 2021 17:58:39 +0000 (10:58 -0700)
This allows for building an outline of the symbols and symbol tables within the IR. This allows for easy navigations to functions/modules and other symbol/symbol table operations within the IR.

Differential Revision: https://reviews.llvm.org/D103729

mlir/include/mlir/Parser/AsmParserState.h
mlir/lib/Parser/AsmParserState.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.cpp
mlir/lib/Tools/mlir-lsp-server/lsp/Protocol.h
mlir/lib/Tools/mlir-lsp-server/lsp/Transport.cpp
mlir/test/mlir-lsp-server/document-symbols.test [new file with mode: 0644]
mlir/test/mlir-lsp-server/initialize-params.test

index a519731..cc010b8 100644 (file)
@@ -53,7 +53,8 @@ public:
       SMDefinition definition;
     };
 
-    OperationDefinition(Operation *op, llvm::SMRange loc) : op(op), loc(loc) {}
+    OperationDefinition(Operation *op, llvm::SMRange loc, llvm::SMLoc endLoc)
+        : op(op), loc(loc), scopeLoc(loc.Start, endLoc) {}
 
     /// The operation representing this definition.
     Operation *op;
@@ -61,6 +62,10 @@ public:
     /// The source location for the operation, i.e. the location of its name.
     llvm::SMRange loc;
 
+    /// The full source range of the operation definition, i.e. a range
+    /// encompassing the start and end of the full operation definition.
+    llvm::SMRange scopeLoc;
+
     /// Source definitions for any result groups of this operation.
     SmallVector<std::pair<unsigned, SMDefinition>> resultGroups;
 
@@ -110,6 +115,10 @@ public:
   /// state.
   iterator_range<OperationDefIterator> getOpDefs() const;
 
+  /// Return the definition for the given operation, or nullptr if the given
+  /// operation does not have a definition.
+  const OperationDefinition *getOpDef(Operation *op) const;
+
   /// Returns (heuristically) the range of an identifier given a SMLoc
   /// corresponding to the start of an identifier location.
   static llvm::SMRange convertIdLocToRange(llvm::SMLoc loc);
@@ -130,7 +139,7 @@ public:
 
   /// Finalize the most recently started operation definition.
   void finalizeOperationDefinition(
-      Operation *op, llvm::SMRange nameLoc,
+      Operation *op, llvm::SMRange nameLoc, llvm::SMLoc endLoc,
       ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups = llvm::None);
 
   /// Start a definition for a region nested under the current operation.
index 31b8684..e8e0479 100644 (file)
@@ -36,8 +36,8 @@ struct AsmParserState::Impl {
     std::unique_ptr<SymbolUseMap> symbolTable;
   };
 
-  /// Resolve any symbol table uses under the given partial operation.
-  void resolveSymbolUses(Operation *op, PartialOpDef &opDef);
+  /// Resolve any symbol table uses in the IR.
+  void resolveSymbolUses();
 
   /// A mapping from operations in the input source file to their parser state.
   SmallVector<std::unique_ptr<OperationDefinition>> operations;
@@ -51,6 +51,10 @@ struct AsmParserState::Impl {
   /// This map should be empty if the parser finishes successfully.
   DenseMap<Value, SmallVector<llvm::SMLoc>> placeholderValueUses;
 
+  /// The symbol table operations within the IR.
+  SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
+      symbolTableOperations;
+
   /// A stack of partial operation definitions that have been started but not
   /// yet finalized.
   SmallVector<PartialOpDef> partialOperations;
@@ -63,22 +67,21 @@ struct AsmParserState::Impl {
   SymbolTableCollection symbolTable;
 };
 
-void AsmParserState::Impl::resolveSymbolUses(Operation *op,
-                                             PartialOpDef &opDef) {
-  assert(opDef.isSymbolTable() && "expected op to be a symbol table");
-
+void AsmParserState::Impl::resolveSymbolUses() {
   SmallVector<Operation *> symbolOps;
-  for (auto &it : *opDef.symbolTable) {
-    symbolOps.clear();
-    if (failed(symbolTable.lookupSymbolIn(op, it.first.cast<SymbolRefAttr>(),
-                                          symbolOps)))
-      continue;
-
-    for (ArrayRef<llvm::SMRange> useRange : it.second) {
-      for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
-        auto opIt = operationToIdx.find(std::get<0>(symIt));
-        if (opIt != operationToIdx.end())
-          operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
+  for (auto &opAndUseMapIt : symbolTableOperations) {
+    for (auto &it : *opAndUseMapIt.second) {
+      symbolOps.clear();
+      if (failed(symbolTable.lookupSymbolIn(
+              opAndUseMapIt.first, it.first.cast<SymbolRefAttr>(), symbolOps)))
+        continue;
+
+      for (ArrayRef<llvm::SMRange> useRange : it.second) {
+        for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
+          auto opIt = operationToIdx.find(std::get<0>(symIt));
+          if (opIt != operationToIdx.end())
+            operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
+        }
       }
     }
   }
@@ -112,8 +115,13 @@ auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
   return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations));
 }
 
-/// Returns (heuristically) the range of an identifier given a SMLoc
-/// corresponding to the start of an identifier location.
+auto AsmParserState::getOpDef(Operation *op) const
+    -> const OperationDefinition * {
+  auto it = impl->operationToIdx.find(op);
+  return it == impl->operationToIdx.end() ? nullptr
+                                          : &*impl->operations[it->second];
+}
+
 llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
   if (!loc.isValid())
     return llvm::SMRange();
@@ -124,7 +132,7 @@ llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
   };
 
   const char *curPtr = loc.getPointer();
-  while (isIdentifierChar(*(++curPtr)))
+  while (*curPtr && isIdentifierChar(*(++curPtr)))
     continue;
   return llvm::SMRange(loc, llvm::SMLoc::getFromPointer(curPtr));
 }
@@ -147,8 +155,11 @@ void AsmParserState::finalize(Operation *topLevelOp) {
   Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
 
   // If this operation is a symbol table, resolve any symbol uses.
-  if (partialOpDef.isSymbolTable())
-    impl->resolveSymbolUses(topLevelOp, partialOpDef);
+  if (partialOpDef.isSymbolTable()) {
+    impl->symbolTableOperations.emplace_back(
+        topLevelOp, std::move(partialOpDef.symbolTable));
+  }
+  impl->resolveSymbolUses();
 }
 
 void AsmParserState::startOperationDefinition(const OperationName &opName) {
@@ -156,7 +167,7 @@ void AsmParserState::startOperationDefinition(const OperationName &opName) {
 }
 
 void AsmParserState::finalizeOperationDefinition(
-    Operation *op, llvm::SMRange nameLoc,
+    Operation *op, llvm::SMRange nameLoc, llvm::SMLoc endLoc,
     ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups) {
   assert(!impl->partialOperations.empty() &&
          "expected valid partial operation definition");
@@ -164,7 +175,7 @@ void AsmParserState::finalizeOperationDefinition(
 
   // Build the full operation definition.
   std::unique_ptr<OperationDefinition> def =
-      std::make_unique<OperationDefinition>(op, nameLoc);
+      std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
   for (auto &resultGroup : resultGroups)
     def->resultGroups.emplace_back(resultGroup.first,
                                    convertIdLocToRange(resultGroup.second));
@@ -172,8 +183,10 @@ void AsmParserState::finalizeOperationDefinition(
   impl->operations.emplace_back(std::move(def));
 
   // If this operation is a symbol table, resolve any symbol uses.
-  if (partialOpDef.isSymbolTable())
-    impl->resolveSymbolUses(op, partialOpDef);
+  if (partialOpDef.isSymbolTable()) {
+    impl->symbolTableOperations.emplace_back(
+        op, std::move(partialOpDef.symbolTable));
+  }
 }
 
 void AsmParserState::startRegionDefinition() {
index 4b80158..08d8dd9 100644 (file)
@@ -166,12 +166,12 @@ namespace {
 /// operations.
 class OperationParser : public Parser {
 public:
-  OperationParser(ParserState &state, Operation *topLevelOp);
+  OperationParser(ParserState &state, ModuleOp topLevelOp);
   ~OperationParser();
 
   /// After parsing is finished, this function must be called to see if there
   /// are any remaining issues.
-  ParseResult finalize(Operation *topLevelOp);
+  ParseResult finalize();
 
   //===--------------------------------------------------------------------===//
   // SSA Value Handling
@@ -399,9 +399,8 @@ private:
 };
 } // end anonymous namespace
 
-OperationParser::OperationParser(ParserState &state, Operation *topLevelOp)
-    : Parser(state), opBuilder(topLevelOp->getRegion(0)),
-      topLevelOp(topLevelOp) {
+OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp)
+    : Parser(state), opBuilder(topLevelOp.getRegion()), topLevelOp(topLevelOp) {
   // The top level operation starts a new name scope.
   pushSSANameScope(/*isIsolated=*/true);
 
@@ -429,7 +428,7 @@ OperationParser::~OperationParser() {
 
 /// After parsing is finished, this function must be called to see if there are
 /// any remaining issues.
-ParseResult OperationParser::finalize(Operation *topLevelOp) {
+ParseResult OperationParser::finalize() {
   // Check for any forward references that are left.  If we find any, error
   // out.
   if (!forwardRefPlaceholders.empty()) {
@@ -466,12 +465,18 @@ ParseResult OperationParser::finalize(Operation *topLevelOp) {
       opOrArgument.get<BlockArgument>().setLoc(locAttr);
   }
 
+  // Pop the top level name scope.
+  if (failed(popSSANameScope()))
+    return failure();
+
+  // Verify that the parsed operations are valid.
+  if (failed(verify(topLevelOp)))
+    return failure();
+
   // If we are populating the parser state, finalize the top-level operation.
   if (state.asmState)
     state.asmState->finalize(topLevelOp);
-
-  // Pop the top level name scope.
-  return popSSANameScope();
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -821,8 +826,9 @@ ParseResult OperationParser::parseOperation() {
         asmResultGroups.emplace_back(resultIt, std::get<2>(record));
         resultIt += std::get<1>(record);
       }
-      state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange(),
-                                                  asmResultGroups);
+      state.asmState->finalizeOperationDefinition(
+          op, nameTok.getLocRange(), /*endLoc=*/getToken().getLoc(),
+          asmResultGroups);
     }
 
     // Add definitions for each of the result groups.
@@ -837,7 +843,8 @@ ParseResult OperationParser::parseOperation() {
 
     // Add this operation to the assembly state if it was provided to populate.
   } else if (state.asmState) {
-    state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange());
+    state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange(),
+                                                /*endLoc=*/getToken().getLoc());
   }
 
   return success();
@@ -1009,7 +1016,8 @@ Operation *OperationParser::parseGenericOperation(Block *insertBlock,
   // If we are populating the parser asm state, finalize this operation
   // definition.
   if (state.asmState)
-    state.asmState->finalizeOperationDefinition(op, nameToken.getLocRange());
+    state.asmState->finalizeOperationDefinition(op, nameToken.getLocRange(),
+                                                /*endLoc=*/getToken().getLoc());
   return op;
 }
 
@@ -2019,6 +2027,10 @@ ParseResult OperationParser::parseRegionBody(
 
   // Add arguments to the entry block.
   if (!entryArguments.empty()) {
+    // If we had named arguments, then don't allow a block name.
+    if (getToken().is(Token::caret_identifier))
+      return emitError("invalid block name in region with named arguments");
+
     for (auto &placeholderArgPair : entryArguments) {
       auto &argInfo = placeholderArgPair.first;
 
@@ -2040,10 +2052,6 @@ ParseResult OperationParser::parseRegionBody(
       if (addDefinition(argInfo, arg))
         return failure();
     }
-
-    // If we had named arguments, then don't allow a block name.
-    if (getToken().is(Token::caret_identifier))
-      return emitError("invalid block name in region with named arguments");
   }
 
   if (parseBlock(block))
@@ -2310,7 +2318,7 @@ ParseResult TopLevelOperationParser::parseTypeAliasDef() {
 ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
                                            Location parserLoc) {
   // Create a top-level operation to contain the parsed state.
-  OwningOpRef<Operation *> topLevelOp(ModuleOp::create(parserLoc));
+  OwningOpRef<ModuleOp> topLevelOp(ModuleOp::create(parserLoc));
   OperationParser opParser(state, topLevelOp.get());
   while (true) {
     switch (getToken().getKind()) {
@@ -2322,16 +2330,12 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
 
     // If we got to the end of the file, then we're done.
     case Token::eof: {
-      if (opParser.finalize(topLevelOp.get()))
-        return failure();
-
-      // Verify that the parsed operations are valid.
-      if (failed(verify(topLevelOp.get())))
+      if (opParser.finalize())
         return failure();
 
       // Splice the blocks of the parsed operation over to the provided
       // top-level block.
-      auto &parsedOps = (*topLevelOp)->getRegion(0).front().getOperations();
+      auto &parsedOps = topLevelOp->getBody()->getOperations();
       auto &destOps = topLevelBlock->getOperations();
       destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()),
                      parsedOps, parsedOps.begin(), parsedOps.end());
index dae0524..0d043ba 100644 (file)
@@ -56,6 +56,16 @@ struct LSPServer::Impl {
   void onHover(const TextDocumentPositionParams &params,
                Callback<Optional<Hover>> reply);
 
+  //===--------------------------------------------------------------------===//
+  // Document Symbols
+
+  void onDocumentSymbol(const DocumentSymbolParams &params,
+                        Callback<std::vector<DocumentSymbol>> reply);
+
+  //===--------------------------------------------------------------------===//
+  // Fields
+  //===--------------------------------------------------------------------===//
+
   MLIRServer &server;
   JSONTransport &transport;
 
@@ -73,6 +83,7 @@ struct LSPServer::Impl {
 
 void LSPServer::Impl::onInitialize(const InitializeParams &params,
                                    Callback<llvm::json::Value> reply) {
+  // Send a response with the capabilities of this server.
   llvm::json::Object serverCaps{
       {"textDocumentSync",
        llvm::json::Object{
@@ -83,6 +94,11 @@ void LSPServer::Impl::onInitialize(const InitializeParams &params,
       {"definitionProvider", true},
       {"referencesProvider", true},
       {"hoverProvider", true},
+
+      // For now we only support documenting symbols when the client supports
+      // hierarchical symbols.
+      {"documentSymbolProvider",
+       params.capabilities.hierarchicalDocumentSymbol},
   };
 
   llvm::json::Object result{
@@ -166,6 +182,17 @@ void LSPServer::Impl::onHover(const TextDocumentPositionParams &params,
 }
 
 //===----------------------------------------------------------------------===//
+// Document Symbols
+
+void LSPServer::Impl::onDocumentSymbol(
+    const DocumentSymbolParams &params,
+    Callback<std::vector<DocumentSymbol>> reply) {
+  std::vector<DocumentSymbol> symbols;
+  server.findDocumentSymbols(params.textDocument.uri, symbols);
+  reply(std::move(symbols));
+}
+
+//===----------------------------------------------------------------------===//
 // LSPServer
 //===----------------------------------------------------------------------===//
 
@@ -198,6 +225,10 @@ LogicalResult LSPServer::run() {
   // Hover
   messageHandler.method("textDocument/hover", impl.get(), &Impl::onHover);
 
+  // Document Symbols
+  messageHandler.method("textDocument/documentSymbol", impl.get(),
+                        &Impl::onDocumentSymbol);
+
   // Diagnostics
   impl->publishDiagnostics =
       messageHandler.outgoingNotification<PublishDiagnosticsParams>(
index 715bc27..ae60d76 100644 (file)
@@ -298,6 +298,18 @@ struct MLIRDocument {
   buildHoverForBlockArgument(llvm::SMRange hoverRange, BlockArgument arg,
                              const AsmParserState::BlockDefinition &block);
 
+  //===--------------------------------------------------------------------===//
+  // Document Symbols
+  //===--------------------------------------------------------------------===//
+
+  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
+  void findDocumentSymbols(Operation *op,
+                           std::vector<lsp::DocumentSymbol> &symbols);
+
+  //===--------------------------------------------------------------------===//
+  // Fields
+  //===--------------------------------------------------------------------===//
+
   /// The context used to hold the state contained by the parsed document.
   MLIRContext context;
 
@@ -595,6 +607,50 @@ lsp::Hover MLIRDocument::buildHoverForBlockArgument(
 }
 
 //===----------------------------------------------------------------------===//
+// MLIRDocument: Document Symbols
+//===----------------------------------------------------------------------===//
+
+void MLIRDocument::findDocumentSymbols(
+    std::vector<lsp::DocumentSymbol> &symbols) {
+  for (Operation &op : parsedIR)
+    findDocumentSymbols(&op, symbols);
+}
+
+void MLIRDocument::findDocumentSymbols(
+    Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
+  std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
+
+  // Check for the source information of this operation.
+  if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
+    // If this operation defines a symbol, record it.
+    if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
+      symbols.emplace_back(symbol.getName(),
+                           op->hasTrait<OpTrait::FunctionLike>()
+                               ? lsp::SymbolKind::Function
+                               : lsp::SymbolKind::Class,
+                           getRangeFromLoc(sourceMgr, def->scopeLoc),
+                           getRangeFromLoc(sourceMgr, def->loc));
+      childSymbols = &symbols.back().children;
+
+    } else if (op->hasTrait<OpTrait::SymbolTable>()) {
+      // Otherwise, if this is a symbol table push an anonymous document symbol.
+      symbols.emplace_back("<" + op->getName().getStringRef() + ">",
+                           lsp::SymbolKind::Namespace,
+                           getRangeFromLoc(sourceMgr, def->scopeLoc),
+                           getRangeFromLoc(sourceMgr, def->loc));
+      childSymbols = &symbols.back().children;
+    }
+  }
+
+  // Recurse into the regions of this operation.
+  if (!op->getNumRegions())
+    return;
+  for (Region &region : op->getRegions())
+    for (Operation &childOp : region.getOps())
+      findDocumentSymbols(&childOp, *childSymbols);
+}
+
+//===----------------------------------------------------------------------===//
 // MLIRTextFileChunk
 //===----------------------------------------------------------------------===//
 
@@ -649,6 +705,7 @@ public:
                         std::vector<lsp::Location> &references);
   Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                  lsp::Position hoverPos);
+  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
 
 private:
   /// Find the MLIR document that contains the given position, and update the
@@ -662,6 +719,9 @@ private:
   /// The version of this file.
   int64_t version;
 
+  /// The number of lines in the file.
+  int64_t totalNumLines;
+
   /// The chunks of this file. The order of these chunks is the order in which
   /// they appear in the text file.
   std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
@@ -671,7 +731,7 @@ private:
 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
                            int64_t version, DialectRegistry &registry,
                            std::vector<lsp::Diagnostic> &diagnostics)
-    : contents(fileContents.str()), version(version) {
+    : contents(fileContents.str()), version(version), totalNumLines(0) {
   // Split the file into separate MLIR documents.
   // TODO: Find a way to share the split file marker with other tools. We don't
   // want to use `splitAndProcessBuffer` here, but we do want to make sure this
@@ -702,6 +762,7 @@ MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
     }
     chunks.emplace_back(std::move(chunk));
   }
+  totalNumLines = lineOffset;
 }
 
 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
@@ -743,6 +804,45 @@ Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
   return hoverInfo;
 }
 
+void MLIRTextFile::findDocumentSymbols(
+    std::vector<lsp::DocumentSymbol> &symbols) {
+  if (chunks.size() == 1)
+    return chunks.front()->document.findDocumentSymbols(symbols);
+
+  // If there are multiple chunks in this file, we create top-level symbols for
+  // each chunk.
+  for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
+    MLIRTextFileChunk &chunk = *chunks[i];
+    lsp::Position startPos(chunk.lineOffset);
+    lsp::Position endPos((i == e - 1) ? totalNumLines - 1
+                                      : chunks[i + 1]->lineOffset);
+    lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
+                               lsp::SymbolKind::Namespace,
+                               /*range=*/lsp::Range(startPos, endPos),
+                               /*selectionRange=*/lsp::Range(startPos));
+    chunk.document.findDocumentSymbols(symbol.children);
+
+    // Fixup the locations of document symbols within this chunk.
+    if (i != 0) {
+      SmallVector<lsp::DocumentSymbol *> symbolsToFix;
+      for (lsp::DocumentSymbol &childSymbol : symbol.children)
+        symbolsToFix.push_back(&childSymbol);
+
+      while (!symbolsToFix.empty()) {
+        lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
+        chunk.adjustLocForChunkOffset(symbol->range);
+        chunk.adjustLocForChunkOffset(symbol->selectionRange);
+
+        for (lsp::DocumentSymbol &childSymbol : symbol->children)
+          symbolsToFix.push_back(&childSymbol);
+      }
+    }
+
+    // Push the symbol for this chunk.
+    symbols.emplace_back(std::move(symbol));
+  }
+}
+
 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
   if (chunks.size() == 1)
     return *chunks.front();
@@ -821,3 +921,10 @@ Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
     return fileIt->second->findHover(uri, hoverPos);
   return llvm::None;
 }
+
+void lsp::MLIRServer::findDocumentSymbols(
+    const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
+  auto fileIt = impl->files.find(uri.file());
+  if (fileIt != impl->files.end())
+    fileIt->second->findDocumentSymbols(symbols);
+}
index 39f9dff..d78d8c3 100644 (file)
@@ -17,6 +17,7 @@ class DialectRegistry;
 
 namespace lsp {
 struct Diagnostic;
+struct DocumentSymbol;
 struct Hover;
 struct Location;
 struct Position;
@@ -55,6 +56,10 @@ public:
   /// couldn't be found.
   Optional<Hover> findHover(const URIForFile &uri, const Position &hoverPos);
 
+  /// Find all of the document symbols within the given file.
+  void findDocumentSymbols(const URIForFile &uri,
+                           std::vector<DocumentSymbol> &symbols);
+
 private:
   struct Impl;
 
index e6bb593..7bce042 100644 (file)
@@ -247,6 +247,28 @@ raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const URIForFile &value) {
 }
 
 //===----------------------------------------------------------------------===//
+// ClientCapabilities
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         ClientCapabilities &result, llvm::json::Path path) {
+  const llvm::json::Object *o = value.getAsObject();
+  if (!o) {
+    path.report("expected object");
+    return false;
+  }
+  if (const llvm::json::Object *textDocument = o->getObject("textDocument")) {
+    if (const llvm::json::Object *documentSymbol =
+            textDocument->getObject("documentSymbol")) {
+      if (Optional<bool> hierarchicalSupport =
+              documentSymbol->getBoolean("hierarchicalDocumentSymbolSupport"))
+        result.hierarchicalDocumentSymbol = *hierarchicalSupport;
+    }
+  }
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
 // InitializeParams
 //===----------------------------------------------------------------------===//
 
@@ -275,6 +297,7 @@ bool mlir::lsp::fromJSON(const llvm::json::Value &value,
   if (!o)
     return false;
   // We deliberately don't fail if we can't parse individual fields.
+  o.map("capabilities", result.capabilities);
   o.map("trace", result.trace);
   return true;
 }
@@ -495,6 +518,33 @@ llvm::json::Value mlir::lsp::toJSON(const Hover &hover) {
 }
 
 //===----------------------------------------------------------------------===//
+// DocumentSymbol
+//===----------------------------------------------------------------------===//
+
+llvm::json::Value mlir::lsp::toJSON(const DocumentSymbol &symbol) {
+  llvm::json::Object result{{"name", symbol.name},
+                            {"kind", static_cast<int>(symbol.kind)},
+                            {"range", symbol.range},
+                            {"selectionRange", symbol.selectionRange}};
+
+  if (!symbol.detail.empty())
+    result["detail"] = symbol.detail;
+  if (!symbol.children.empty())
+    result["children"] = symbol.children;
+  return std::move(result);
+}
+
+//===----------------------------------------------------------------------===//
+// DocumentSymbolParams
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         DocumentSymbolParams &result, llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("textDocument", result.textDocument);
+}
+
+//===----------------------------------------------------------------------===//
 // DiagnosticRelatedInformation
 //===----------------------------------------------------------------------===//
 
index d9e9ad1..d0f7b05 100644 (file)
@@ -135,6 +135,20 @@ bool fromJSON(const llvm::json::Value &value, URIForFile &result,
 raw_ostream &operator<<(raw_ostream &os, const URIForFile &value);
 
 //===----------------------------------------------------------------------===//
+// ClientCapabilities
+//===----------------------------------------------------------------------===//
+
+struct ClientCapabilities {
+  /// Client supports hierarchical document symbols.
+  /// textDocument.documentSymbol.hierarchicalDocumentSymbolSupport
+  bool hierarchicalDocumentSymbol = false;
+};
+
+/// Add support for JSON serialization.
+bool fromJSON(const llvm::json::Value &value, ClientCapabilities &result,
+              llvm::json::Path path);
+
+//===----------------------------------------------------------------------===//
 // InitializeParams
 //===----------------------------------------------------------------------===//
 
@@ -149,6 +163,9 @@ bool fromJSON(const llvm::json::Value &value, TraceLevel &result,
               llvm::json::Path path);
 
 struct InitializeParams {
+  /// The capabilities provided by the client (editor or tool).
+  ClientCapabilities capabilities;
+
   /// The initial trace setting. If omitted trace is disabled ('off').
   Optional<TraceLevel> trace;
 };
@@ -224,6 +241,9 @@ bool fromJSON(const llvm::json::Value &value,
 //===----------------------------------------------------------------------===//
 
 struct Position {
+  Position(int line = 0, int character = 0)
+      : line(line), character(character) {}
+
   /// Line position in a document (zero-based).
   int line = 0;
 
@@ -450,6 +470,94 @@ struct Hover {
 llvm::json::Value toJSON(const Hover &hover);
 
 //===----------------------------------------------------------------------===//
+// SymbolKind
+//===----------------------------------------------------------------------===//
+
+enum class SymbolKind {
+  File = 1,
+  Module = 2,
+  Namespace = 3,
+  Package = 4,
+  Class = 5,
+  Method = 6,
+  Property = 7,
+  Field = 8,
+  Constructor = 9,
+  Enum = 10,
+  Interface = 11,
+  Function = 12,
+  Variable = 13,
+  Constant = 14,
+  String = 15,
+  Number = 16,
+  Boolean = 17,
+  Array = 18,
+  Object = 19,
+  Key = 20,
+  Null = 21,
+  EnumMember = 22,
+  Struct = 23,
+  Event = 24,
+  Operator = 25,
+  TypeParameter = 26
+};
+
+//===----------------------------------------------------------------------===//
+// DocumentSymbol
+//===----------------------------------------------------------------------===//
+
+/// Represents programming constructs like variables, classes, interfaces etc.
+/// that appear in a document. Document symbols can be hierarchical and they
+/// have two ranges: one that encloses its definition and one that points to its
+/// most interesting range, e.g. the range of an identifier.
+struct DocumentSymbol {
+  DocumentSymbol() = default;
+  DocumentSymbol(DocumentSymbol &&) = default;
+  DocumentSymbol(const Twine &name, SymbolKind kind, Range range,
+                 Range selectionRange)
+      : name(name.str()), kind(kind), range(range),
+        selectionRange(selectionRange) {}
+
+  /// The name of this symbol.
+  std::string name;
+
+  /// More detail for this symbol, e.g the signature of a function.
+  std::string detail;
+
+  /// The kind of this symbol.
+  SymbolKind kind;
+
+  /// The range enclosing this symbol not including leading/trailing whitespace
+  /// but everything else like comments. This information is typically used to
+  /// determine if the clients cursor is inside the symbol to reveal in the
+  /// symbol in the UI.
+  Range range;
+
+  /// The range that should be selected and revealed when this symbol is being
+  /// picked, e.g the name of a function. Must be contained by the `range`.
+  Range selectionRange;
+
+  /// Children of this symbol, e.g. properties of a class.
+  std::vector<DocumentSymbol> children;
+};
+
+/// Add support for JSON serialization.
+llvm::json::Value toJSON(const DocumentSymbol &symbol);
+
+//===----------------------------------------------------------------------===//
+// DocumentSymbolParams
+//===----------------------------------------------------------------------===//
+
+struct DocumentSymbolParams {
+  // The text document to find symbols in.
+  TextDocumentIdentifier textDocument;
+};
+
+/// Add support for JSON serialization.
+bool fromJSON(const llvm::json::Value &value, DocumentSymbolParams &result,
+              llvm::json::Path path);
+
+//===----------------------------------------------------------------------===//
 // DiagnosticRelatedInformation
 //===----------------------------------------------------------------------===//
 
index 3d9e21d..768fd2e 100644 (file)
@@ -204,6 +204,8 @@ llvm::Error JSONTransport::run(MessageHandler &handler) {
       if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) {
         if (!handleMessage(std::move(*doc), handler))
           return llvm::Error::success();
+      } else {
+        Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError()));
       }
     }
   }
diff --git a/mlir/test/mlir-lsp-server/document-symbols.test b/mlir/test/mlir-lsp-server/document-symbols.test
new file mode 100644 (file)
index 0000000..17d5d11
--- /dev/null
@@ -0,0 +1,71 @@
+// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootUri":"test:///workspace","capabilities":{"textDocument":{"documentSymbol":{"hierarchicalDocumentSymbolSupport":true}}},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+  "uri":"test:///foo.mlir",
+  "languageId":"mlir",
+  "version":1,
+  "text":"module {\nfunc private @foo()\n}"
+}}}
+// -----
+{"jsonrpc":"2.0","id":1,"method":"textDocument/documentSymbol","params":{
+  "textDocument":{"uri":"test:///foo.mlir"}
+}}
+//      CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": [
+// CHECK-NEXT:    {
+// CHECK-NEXT:      "children": [
+// CHECK-NEXT:        {
+// CHECK-NEXT:          "kind": 12,
+// CHECK-NEXT:          "name": "foo",
+// CHECK-NEXT:          "range": {
+// CHECK-NEXT:            "end": {
+// CHECK-NEXT:              "character": {{.*}},
+// CHECK-NEXT:              "line": {{.*}}
+// CHECK-NEXT:            },
+// CHECK-NEXT:            "start": {
+// CHECK-NEXT:              "character": {{.*}},
+// CHECK-NEXT:              "line": {{.*}}
+// CHECK-NEXT:            }
+// CHECK-NEXT:          },
+// CHECK-NEXT:          "selectionRange": {
+// CHECK-NEXT:            "end": {
+// CHECK-NEXT:              "character": 4,
+// CHECK-NEXT:              "line": {{.*}}
+// CHECK-NEXT:            },
+// CHECK-NEXT:            "start": {
+// CHECK-NEXT:              "character": {{.*}},
+// CHECK-NEXT:              "line": {{.*}}
+// CHECK-NEXT:            }
+// CHECK-NEXT:          }
+// CHECK-NEXT:        }
+// CHECK-NEXT:      ],
+// CHECK-NEXT:      "kind": 3,
+// CHECK-NEXT:      "name": "<module>",
+// CHECK-NEXT:      "range": {
+// CHECK-NEXT:        "end": {
+// CHECK-NEXT:          "character": {{.*}},
+// CHECK-NEXT:          "line": {{.*}}
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "start": {
+// CHECK-NEXT:          "character": {{.*}},
+// CHECK-NEXT:          "line": {{.*}}
+// CHECK-NEXT:        }
+// CHECK-NEXT:      },
+// CHECK-NEXT:      "selectionRange": {
+// CHECK-NEXT:        "end": {
+// CHECK-NEXT:          "character": {{.*}},
+// CHECK-NEXT:          "line": {{.*}}
+// CHECK-NEXT:        },
+// CHECK-NEXT:        "start": {
+// CHECK-NEXT:          "character": {{.*}},
+// CHECK-NEXT:          "line": {{.*}}
+// CHECK-NEXT:        }
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  ]
+// -----
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+// -----
+{"jsonrpc":"2.0","method":"exit"}
index 261d024..db41a61 100644 (file)
@@ -6,6 +6,7 @@
 // CHECK-NEXT:  "result": {
 // CHECK-NEXT:    "capabilities": {
 // CHECK-NEXT:      "definitionProvider": true,
+// CHECK-NEXT:      "documentSymbolProvider": false,
 // CHECK-NEXT:      "hoverProvider": true,
 // CHECK-NEXT:      "referencesProvider": true,
 // CHECK-NEXT:      "textDocumentSync": {