[mlir] Execute all requested translations in MlirTranslateMain
authorFabian <fabian.tschopp@modular.com>
Sat, 18 Feb 2023 20:31:37 +0000 (21:31 +0100)
committerFabian <fabian.tschopp@modular.com>
Sat, 18 Feb 2023 20:31:59 +0000 (21:31 +0100)
Currently, MlirTranslateMain only executes one of the requested translations, and does not error if multiple are specified. This commit enables translations to be chained in the specified order.

This makes round-trip tests easier, since existing import/export passes can be reused and no combined round-trip passes have to be registered (example: mlir-translate  -serialize-spirv -deserialize-spirv).

Additionally, by leveraging TranslateRegistration with file-to-file TranslateFunctions, generic pre- and post-processing can be added before/after conversion to/from MLIR.

Reviewed By: lattner, Mogball

Differential Revision: https://reviews.llvm.org/D143719

mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir [new file with mode: 0644]

index 02c9128..5e460ed 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/IR/Verifier.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/ToolUtilities.h"
 #include "mlir/Tools/mlir-translate/Translation.h"
 #include "llvm/Support/InitLLVM.h"
@@ -56,9 +57,9 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
   llvm::InitLLVM y(argc, argv);
 
   // Add flags for all the registered translations.
-  llvm::cl::opt<const Translation *, false, TranslationParser>
-      translationRequested("", llvm::cl::desc("Translation to perform"),
-                           llvm::cl::Required);
+  llvm::cl::list<const Translation *, bool, TranslationParser>
+      translationsRequested("", llvm::cl::desc("Translations to perform"),
+                            llvm::cl::Required);
   registerAsmPrinterCLOptions();
   registerMLIRContextCLOptions();
   registerTranslationCLOptions();
@@ -66,7 +67,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
 
   std::string errorMessage;
   std::unique_ptr<llvm::MemoryBuffer> input;
-  if (auto inputAlignment = translationRequested->getInputAlignment())
+  if (auto inputAlignment = translationsRequested[0]->getInputAlignment())
     input = openInputFile(inputFilename, *inputAlignment, &errorMessage);
   else
     input = openInputFile(inputFilename, &errorMessage);
@@ -84,23 +85,54 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
   // Processes the memory buffer with a new MLIRContext.
   auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
                            raw_ostream &os) {
-    MLIRContext context;
-    context.allowUnregisteredDialects(allowUnregisteredDialects);
-    context.printOpOnDiagnostic(!verifyDiagnostics);
-    auto sourceMgr = std::make_shared<llvm::SourceMgr>();
-    sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
-
-    if (!verifyDiagnostics) {
-      SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
-      return (*translationRequested)(sourceMgr, os, &context);
+    // Temporary buffers for chained translation processing.
+    std::string dataIn;
+    std::string dataOut;
+    LogicalResult result = LogicalResult::success();
+
+    for (size_t i = 0, e = translationsRequested.size(); i < e; ++i) {
+      llvm::raw_ostream *stream;
+      llvm::raw_string_ostream dataStream(dataOut);
+
+      if (i == e - 1) {
+        // Output last translation to output.
+        stream = &os;
+      } else {
+        // Output translation to temporary data buffer.
+        stream = &dataStream;
+      }
+
+      const Translation *translationRequested = translationsRequested[i];
+      MLIRContext context;
+      context.allowUnregisteredDialects(allowUnregisteredDialects);
+      context.printOpOnDiagnostic(!verifyDiagnostics);
+      auto sourceMgr = std::make_shared<llvm::SourceMgr>();
+      sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
+
+      if (verifyDiagnostics) {
+        // In the diagnostic verification flow, we ignore whether the
+        // translation failed (in most cases, it is expected to fail).
+        // Instead, we check if the diagnostics were produced as expected.
+        SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr,
+                                                            &context);
+        (void)(*translationRequested)(sourceMgr, os, &context);
+        result = sourceMgrHandler.verify();
+      } else {
+        SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
+        result = (*translationRequested)(sourceMgr, *stream, &context);
+      }
+      if (failed(result))
+        return result;
+
+      if (i < e - 1) {
+        // If there are further translations, create a new buffer with the
+        // output data.
+        dataIn = dataOut;
+        dataOut.clear();
+        ownedBuffer = llvm::MemoryBuffer::getMemBuffer(dataIn);
+      }
     }
-
-    // In the diagnostic verification flow, we ignore whether the translation
-    // failed (in most cases, it is expected to fail). Instead, we check if the
-    // diagnostics were produced as expected.
-    SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
-    (void)(*translationRequested)(sourceMgr, os, &context);
-    return sourceMgrHandler.verify();
+    return result;
   };
 
   if (failed(splitAndProcessBuffer(std::move(input), processBuffer,
diff --git a/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir b/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir
new file mode 100644 (file)
index 0000000..427b926
--- /dev/null
@@ -0,0 +1,18 @@
+// RUN: mlir-translate -no-implicit-module -split-input-file -serialize-spirv -deserialize-spirv %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  spirv.func @array_stride(%arg0 : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32, stride=4>, stride=128>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" {
+    // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.array<4 x !spirv.array<4 x f32, stride=4>, stride=128>, StorageBuffer>, i32, i32
+    %2 = spirv.AccessChain %arg0[%arg1, %arg2] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32, stride=4>, stride=128>, StorageBuffer>, i32, i32
+    spirv.Return
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  // CHECK: spirv.GlobalVariable {{@.*}} : !spirv.ptr<!spirv.rtarray<f32, stride=4>, StorageBuffer>
+  spirv.GlobalVariable @var0 : !spirv.ptr<!spirv.rtarray<f32, stride=4>, StorageBuffer>
+  // CHECK: spirv.GlobalVariable {{@.*}} : !spirv.ptr<!spirv.rtarray<vector<4xf16>>, Input>
+  spirv.GlobalVariable @var1 : !spirv.ptr<!spirv.rtarray<vector<4xf16>>, Input>
+}