[mlir:PDLL-LSP] Add a custom LSP command for viewing the output of PDLL
authorRiver Riddle <riddleriver@gmail.com>
Sat, 30 Apr 2022 07:29:49 +0000 (00:29 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 31 May 2022 00:35:34 +0000 (17:35 -0700)
This commit adds a new PDLL specific LSP command, pdll.viewOutput, that
allows for viewing the intermediate outputs of a given PDLL file. The available
intermediate forms currently mirror those in mlir-pdll, namely: AST, MLIR, CPP.
This is extremely useful for a developer of PDLL, as it simplifies various testing,
and is also quite useful for users as they can easily view what is actually being
generated for their PDLL files.

This new command is added to the vscode client, and is available in the right
client context menu of PDLL files, or via the vscode command palette.

Differential Revision: https://reviews.llvm.org/D124783

14 files changed:
mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp
mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt
mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp [new file with mode: 0644]
mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h [new file with mode: 0644]
mlir/test/mlir-pdll-lsp-server/view-output.test [new file with mode: 0644]
mlir/utils/vscode/package.json
mlir/utils/vscode/src/PDLL/commands/viewOutput.ts [new file with mode: 0644]
mlir/utils/vscode/src/PDLL/pdll.ts [new file with mode: 0644]
mlir/utils/vscode/src/command.ts [new file with mode: 0644]
mlir/utils/vscode/src/extension.ts
mlir/utils/vscode/src/mlirContext.ts

index 31a6a30..229e50a 100644 (file)
@@ -8,7 +8,7 @@
 
 #include "CompilationDatabase.h"
 #include "../lsp-server-support/Logging.h"
-#include "../lsp-server-support/Protocol.h"
+#include "Protocol.h"
 #include "mlir/Support/FileUtilities.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringRef.h"
index f9ff547..bf25b7e 100644 (file)
@@ -1,12 +1,14 @@
 llvm_add_library(MLIRPdllLspServerLib
   LSPServer.cpp
   PDLLServer.cpp
+  Protocol.cpp
   MlirPdllLspServerMain.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-pdll-lsp-server
 
   LINK_LIBS PUBLIC
+  MLIRPDLLCodeGen
   MLIRPDLLParser
   MLIRLspServerSupportLib
   )
index 5f4967f..ff955a7 100644 (file)
@@ -9,9 +9,9 @@
 #include "LSPServer.h"
 
 #include "../lsp-server-support/Logging.h"
-#include "../lsp-server-support/Protocol.h"
 #include "../lsp-server-support/Transport.h"
 #include "PDLLServer.h"
+#include "Protocol.h"
 #include "llvm/ADT/FunctionExtras.h"
 #include "llvm/ADT/StringMap.h"
 
@@ -83,6 +83,12 @@ struct LSPServer {
                        Callback<SignatureHelp> reply);
 
   //===--------------------------------------------------------------------===//
+  // PDLL View Output
+
+  void onPDLLViewOutput(const PDLLViewOutputParams &params,
+                        Callback<Optional<PDLLViewOutputResult>> reply);
+
+  //===--------------------------------------------------------------------===//
   // Fields
   //===--------------------------------------------------------------------===//
 
@@ -249,6 +255,15 @@ void LSPServer::onSignatureHelp(const TextDocumentPositionParams &params,
 }
 
 //===----------------------------------------------------------------------===//
+// PDLL ViewOutput
+
+void LSPServer::onPDLLViewOutput(
+    const PDLLViewOutputParams &params,
+    Callback<Optional<PDLLViewOutputResult>> reply) {
+  reply(server.getPDLLViewOutput(params.uri, params.kind));
+}
+
+//===----------------------------------------------------------------------===//
 // Entry Point
 //===----------------------------------------------------------------------===//
 
@@ -296,6 +311,10 @@ LogicalResult mlir::lsp::runPdllLSPServer(PDLLServer &server,
   messageHandler.method("textDocument/signatureHelp", &lspServer,
                         &LSPServer::onSignatureHelp);
 
+  // PDLL ViewOutput
+  messageHandler.method("pdll/viewOutput", &lspServer,
+                        &LSPServer::onPDLLViewOutput);
+
   // Diagnostics
   lspServer.publishDiagnostics =
       messageHandler.outgoingNotification<PublishDiagnosticsParams>(
index cbfd421..215a0e6 100644 (file)
 
 #include "../lsp-server-support/CompilationDatabase.h"
 #include "../lsp-server-support/Logging.h"
-#include "../lsp-server-support/Protocol.h"
 #include "../lsp-server-support/SourceMgrUtils.h"
+#include "Protocol.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Tools/PDLL/AST/Context.h"
 #include "mlir/Tools/PDLL/AST/Nodes.h"
 #include "mlir/Tools/PDLL/AST/Types.h"
+#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
+#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
 #include "mlir/Tools/PDLL/ODS/Constraint.h"
 #include "mlir/Tools/PDLL/ODS/Context.h"
 #include "mlir/Tools/PDLL/ODS/Dialect.h"
@@ -305,6 +308,12 @@ struct PDLDocument {
                                       const lsp::Position &helpPos);
 
   //===--------------------------------------------------------------------===//
+  // PDLL ViewOutput
+  //===--------------------------------------------------------------------===//
+
+  void getPDLLViewOutput(raw_ostream &os, lsp::PDLLViewOutputKind kind);
+
+  //===--------------------------------------------------------------------===//
   // Fields
   //===--------------------------------------------------------------------===//
 
@@ -1086,6 +1095,39 @@ lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri,
 }
 
 //===----------------------------------------------------------------------===//
+// PDLL ViewOutput
+//===----------------------------------------------------------------------===//
+
+void PDLDocument::getPDLLViewOutput(raw_ostream &os,
+                                    lsp::PDLLViewOutputKind kind) {
+  if (failed(astModule))
+    return;
+  if (kind == lsp::PDLLViewOutputKind::AST) {
+    (*astModule)->print(os);
+    return;
+  }
+
+  // Generate the MLIR for the ast module. We also capture diagnostics here to
+  // show to the user, which may be useful if PDLL isn't capturing constraints
+  // expected by PDL.
+  MLIRContext mlirContext;
+  SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
+  OwningOpRef<ModuleOp> pdlModule =
+      codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **astModule);
+  if (!pdlModule)
+    return;
+  if (kind == lsp::PDLLViewOutputKind::MLIR) {
+    pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
+    return;
+  }
+
+  // Otherwise, generate the output for C++.
+  assert(kind == lsp::PDLLViewOutputKind::CPP &&
+         "unexpected PDLLViewOutputKind");
+  codegenPDLLToCPP(**astModule, *pdlModule, os);
+}
+
+//===----------------------------------------------------------------------===//
 // PDLTextFileChunk
 //===----------------------------------------------------------------------===//
 
@@ -1148,6 +1190,7 @@ public:
                                         lsp::Position completePos);
   lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
                                       lsp::Position helpPos);
+  lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind);
 
 private:
   /// Find the PDL document that contains the given position, and update the
@@ -1321,6 +1364,21 @@ lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri,
   return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
 }
 
+lsp::PDLLViewOutputResult
+PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
+  lsp::PDLLViewOutputResult result;
+  {
+    llvm::raw_string_ostream outputOS(result.output);
+    llvm::interleave(
+        llvm::make_pointee_range(chunks),
+        [&](PDLTextFileChunk &chunk) {
+          chunk.document.getPDLLViewOutput(outputOS, kind);
+        },
+        [&] { outputOS << "\n// -----\n\n"; });
+  }
+  return result;
+}
+
 PDLTextFileChunk &PDLTextFile::getChunkFor(lsp::Position &pos) {
   if (chunks.size() == 1)
     return *chunks.front();
@@ -1439,3 +1497,12 @@ lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri,
     return fileIt->second->getSignatureHelp(uri, helpPos);
   return SignatureHelp();
 }
+
+Optional<lsp::PDLLViewOutputResult>
+lsp::PDLLServer::getPDLLViewOutput(const URIForFile &uri,
+                                   PDLLViewOutputKind kind) {
+  auto fileIt = impl->files.find(uri.file());
+  if (fileIt != impl->files.end())
+    return fileIt->second->getPDLLViewOutput(kind);
+  return llvm::None;
+}
index 12716f4..0fecc35 100644 (file)
@@ -18,6 +18,8 @@ namespace mlir {
 namespace lsp {
 struct Diagnostic;
 class CompilationDatabase;
+struct PDLLViewOutputResult;
+enum class PDLLViewOutputKind;
 struct CompletionList;
 struct DocumentLink;
 struct DocumentSymbol;
@@ -88,6 +90,11 @@ public:
   SignatureHelp getSignatureHelp(const URIForFile &uri,
                                  const Position &helpPos);
 
+  /// Get the output of the given PDLL file, or None if there is no valid
+  /// output.
+  Optional<PDLLViewOutputResult> getPDLLViewOutput(const URIForFile &uri,
+                                                   PDLLViewOutputKind kind);
+
 private:
   struct Impl;
   std::unique_ptr<Impl> impl;
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp
new file mode 100644 (file)
index 0000000..8b767ee
--- /dev/null
@@ -0,0 +1,77 @@
+//===--- Protocol.cpp - Language Server Protocol Implementation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the serialization code for the PDLL specific LSP structs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Protocol.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::lsp;
+
+// Helper that doesn't treat `null` and absent fields as failures.
+template <typename T>
+static bool mapOptOrNull(const llvm::json::Value &params,
+                         llvm::StringLiteral prop, T &out,
+                         llvm::json::Path path) {
+  const llvm::json::Object *o = params.getAsObject();
+  assert(o);
+
+  // Field is missing or null.
+  auto *v = o->get(prop);
+  if (!v || v->getAsNull().hasValue())
+    return true;
+  return fromJSON(*v, out, path.field(prop));
+}
+
+//===----------------------------------------------------------------------===//
+// PDLLViewOutputParams
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         PDLLViewOutputKind &result, llvm::json::Path path) {
+  if (Optional<StringRef> str = value.getAsString()) {
+    if (*str == "ast") {
+      result = PDLLViewOutputKind::AST;
+      return true;
+    }
+    if (*str == "mlir") {
+      result = PDLLViewOutputKind::MLIR;
+      return true;
+    }
+    if (*str == "cpp") {
+      result = PDLLViewOutputKind::CPP;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         PDLLViewOutputParams &result, llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("uri", result.uri) && o.map("kind", result.kind);
+}
+
+//===----------------------------------------------------------------------===//
+// PDLLViewOutputResult
+//===----------------------------------------------------------------------===//
+
+llvm::json::Value mlir::lsp::toJSON(const PDLLViewOutputResult &value) {
+  return llvm::json::Object{{"output", value.output}};
+}
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h
new file mode 100644 (file)
index 0000000..3de2ae0
--- /dev/null
@@ -0,0 +1,69 @@
+//===--- Protocol.h - Language Server Protocol Implementation ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains structs for LSP commands that are specific to the PDLL
+// server.
+//
+// Each struct has a toJSON and fromJSON function, that converts between
+// the struct and a JSON representation. (See JSON.h)
+//
+// Some structs also have operator<< serialization. This is for debugging and
+// tests, and is not generally machine-readable.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_
+#define LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_
+
+#include "../lsp-server-support/Protocol.h"
+
+namespace mlir {
+namespace lsp {
+//===----------------------------------------------------------------------===//
+// PDLLViewOutputParams
+//===----------------------------------------------------------------------===//
+
+/// The type of output to view from PDLL.
+enum class PDLLViewOutputKind {
+  AST,
+  MLIR,
+  CPP,
+};
+
+/// Represents the parameters used when viewing the output of a PDLL file.
+struct PDLLViewOutputParams {
+  /// The URI of the document to view the output of.
+  URIForFile uri;
+
+  /// The kind of output to generate.
+  PDLLViewOutputKind kind;
+};
+
+/// Add support for JSON serialization.
+bool fromJSON(const llvm::json::Value &value, PDLLViewOutputKind &result,
+              llvm::json::Path path);
+bool fromJSON(const llvm::json::Value &value, PDLLViewOutputParams &result,
+              llvm::json::Path path);
+
+//===----------------------------------------------------------------------===//
+// PDLLViewOutputResult
+//===----------------------------------------------------------------------===//
+
+/// Represents the result of viewing the output of a PDLL file.
+struct PDLLViewOutputResult {
+  /// The string representation of the output.
+  std::string output;
+};
+
+/// Add support for JSON serialization.
+llvm::json::Value toJSON(const PDLLViewOutputResult &value);
+
+} // namespace lsp
+} // namespace mlir
+
+#endif
diff --git a/mlir/test/mlir-pdll-lsp-server/view-output.test b/mlir/test/mlir-pdll-lsp-server/view-output.test
new file mode 100644 (file)
index 0000000..c45ac2e
--- /dev/null
@@ -0,0 +1,43 @@
+// RUN: mlir-pdll-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+  "uri":"test:///foo.pdll",
+  "languageId":"pdll",
+  "version":1,
+  "text":"Pattern TestPat => erase op<test.op>;"
+}}}
+// -----
+{"jsonrpc":"2.0","id":1,"method":"pdll/viewOutput","params":{
+  "uri":"test:///foo.pdll",
+  "kind":"ast"
+}}
+//      CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "output": "-Module{{.*}}PatternDecl{{.*}}Name<TestPat>{{.*}}\n"
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":2,"method":"pdll/viewOutput","params":{
+  "uri":"test:///foo.pdll",
+  "kind":"mlir"
+}}
+//      CHECK:  "id": 2
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "output": "module {\n  pdl.pattern @TestPat {{.*}}\n"
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":3,"method":"pdll/viewOutput","params":{
+  "uri":"test:///foo.pdll",
+  "kind":"cpp"
+}}
+//      CHECK:  "id": 3
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "output": "{{.*}}struct TestPat : ::mlir::PDLPatternModule{{.*}}\n"
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+// -----
+{"jsonrpc":"2.0","method":"exit"}
index cca2ba1..d0a0493 100644 (file)
       {
         "command": "mlir.restart",
         "title": "mlir: Restart language server"
+      },
+      {
+        "command": "mlir.viewPDLLOutput",
+        "title": "mlir-pdll: View PDLL output"
       }
-    ]
+    ],
+    "menus": {
+      "editor/context": [
+        {
+          "command": "mlir.viewPDLLOutput",
+          "group": "z_commands",
+          "when": "editorLangId == pdll"
+        }
+      ]
+    }
   }
 }
diff --git a/mlir/utils/vscode/src/PDLL/commands/viewOutput.ts b/mlir/utils/vscode/src/PDLL/commands/viewOutput.ts
new file mode 100644 (file)
index 0000000..4a666a0
--- /dev/null
@@ -0,0 +1,66 @@
+import * as vscode from 'vscode'
+
+import {Command} from '../../command';
+import {MLIRContext} from '../../mlirContext';
+
+/**
+ * The parameters to the pdll/viewOutput command. These parameters are:
+ * - `uri`: The URI of the file to view.
+ * - `kind`: The kind of the output to generate.
+ */
+type ViewOutputParams = Partial<{uri : string, kind : string}>;
+
+/**
+ * The output of the commands:
+ * - `output`: The output string of the command, e.g. a .mlir PDL string.
+ */
+type ViewOutputResult = Partial<{output : string}>;
+
+/**
+ * A command that displays the output of the current PDLL document.
+ */
+export class ViewPDLLCommand extends Command {
+  constructor(context: MLIRContext) { super('mlir.viewPDLLOutput', context); }
+
+  async execute() {
+    const editor = vscode.window.activeTextEditor;
+    if (editor.document.languageId != 'pdll')
+      return;
+
+    // Check to see if a language client is active for this document.
+    const workspaceFolder =
+        vscode.workspace.getWorkspaceFolder(editor.document.uri);
+    const pdllClient = this.context.getLanguageClient(workspaceFolder, "pdll");
+    if (!pdllClient) {
+      return;
+    }
+
+    // Ask the user for the desired output type.
+    const outputType =
+        await vscode.window.showQuickPick([ 'ast', 'mlir', 'cpp' ]);
+    if (!outputType) {
+      return;
+    }
+
+    // If we have the language client, ask it to try compiling the document.
+    let outputParams: ViewOutputParams = {
+      uri : editor.document.uri.toString(),
+      kind : outputType,
+    };
+    const result: ViewOutputResult|undefined =
+        await pdllClient.sendRequest('pdll/viewOutput', outputParams);
+    if (!result || result.output.length === 0) {
+      return;
+    }
+
+    // Display the output in a new editor.
+    let outputFileType = 'plaintext';
+    if (outputType == 'mlir') {
+      outputFileType = 'mlir';
+    } else if (outputType == 'cpp') {
+      outputFileType = 'cpp';
+    }
+    await vscode.workspace.openTextDocument(
+        {language : outputFileType, content : result.output});
+  }
+}
diff --git a/mlir/utils/vscode/src/PDLL/pdll.ts b/mlir/utils/vscode/src/PDLL/pdll.ts
new file mode 100644 (file)
index 0000000..bc65837
--- /dev/null
@@ -0,0 +1,12 @@
+import * as vscode from 'vscode';
+
+import {MLIRContext} from '../mlirContext';
+import {ViewPDLLCommand} from './commands/viewOutput';
+
+/**
+ *  Register the necessary context and commands for PDLL.
+ */
+export function registerPDLLCommands(context: vscode.ExtensionContext,
+                                     mlirContext: MLIRContext) {
+  context.subscriptions.push(new ViewPDLLCommand(mlirContext));
+}
diff --git a/mlir/utils/vscode/src/command.ts b/mlir/utils/vscode/src/command.ts
new file mode 100644 (file)
index 0000000..4623a5b
--- /dev/null
@@ -0,0 +1,25 @@
+import * as vscode from 'vscode';
+import {MLIRContext} from './mlirContext';
+
+/**
+ * This class represents a base vscode command. It handles all of the necessary
+ * command registration and disposal boilerplate.
+ */
+export abstract class Command extends vscode.Disposable {
+  private disposable: vscode.Disposable;
+  protected context: MLIRContext;
+
+  constructor(command: string, context: MLIRContext) {
+    super(() => this.dispose());
+    this.disposable =
+        vscode.commands.registerCommand(command, this.execute, this);
+    this.context = context;
+  }
+
+  dispose() { this.disposable && this.disposable.dispose(); }
+
+  /**
+   * The function executed when this command is invoked.
+   */
+  abstract execute(...args: any[]): any;
+}
index 7d2d4a7..72c754d 100644 (file)
@@ -1,6 +1,7 @@
 import * as vscode from 'vscode';
 
 import {MLIRContext} from './mlirContext';
+import {registerPDLLCommands} from './PDLL/pdll';
 
 /**
  *  This method is called when the extension is activated. The extension is
@@ -20,6 +21,7 @@ export function activate(context: vscode.ExtensionContext) {
         mlirContext.dispose();
         await mlirContext.activate(outputChannel);
       }));
+  registerPDLLCommands(context, mlirContext);
 
   mlirContext.activate(outputChannel);
 }
index c0c2b53..859d616 100644 (file)
@@ -356,6 +356,21 @@ export class MLIRContext implements vscode.Disposable {
     return this.resolvePath(serverPath, defaultPath, workspaceFolder);
   }
 
+  /**
+   * Return the language client for the given language and workspace folder, or
+   * null if no client is active.
+   */
+  getLanguageClient(workspaceFolder: vscode.WorkspaceFolder,
+                    languageName: string): vscodelc.LanguageClient {
+    let workspaceFolderStr =
+        workspaceFolder ? workspaceFolder.uri.toString() : "";
+    let folderContext = this.workspaceFolders.get(workspaceFolderStr);
+    if (!folderContext) {
+      return null;
+    }
+    return folderContext.clients.get(languageName);
+  }
+
   dispose() {
     this.subscriptions.forEach((d) => { d.dispose(); });
     this.subscriptions = [];