From 09fd9ef4f4cb5c7b25bdb9f97d7e92445aec3417 Mon Sep 17 00:00:00 2001 From: Fabian Date: Sat, 18 Feb 2023 21:31:37 +0100 Subject: [PATCH] [mlir] Execute all requested translations in MlirTranslateMain Currently, MlirTranslateMain only executes one of the requested translations, and does not error if multiple are specified. This commit enables translations to be chained in the specified order. This makes round-trip tests easier, since existing import/export passes can be reused and no combined round-trip passes have to be registered (example: mlir-translate -serialize-spirv -deserialize-spirv). Additionally, by leveraging TranslateRegistration with file-to-file TranslateFunctions, generic pre- and post-processing can be added before/after conversion to/from MLIR. Reviewed By: lattner, Mogball Differential Revision: https://reviews.llvm.org/D143719 --- .../lib/Tools/mlir-translate/MlirTranslateMain.cpp | 72 ++++++++++++++++------ .../Target/SPIRV/array-two-step-roundtrip.mlir | 18 ++++++ 2 files changed, 70 insertions(+), 20 deletions(-) create mode 100644 mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp index 02c9128..5e460ed 100644 --- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp +++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Support/ToolUtilities.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "llvm/Support/InitLLVM.h" @@ -56,9 +57,9 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, llvm::InitLLVM y(argc, argv); // Add flags for all the registered translations. - llvm::cl::opt - translationRequested("", llvm::cl::desc("Translation to perform"), - llvm::cl::Required); + llvm::cl::list + translationsRequested("", llvm::cl::desc("Translations to perform"), + llvm::cl::Required); registerAsmPrinterCLOptions(); registerMLIRContextCLOptions(); registerTranslationCLOptions(); @@ -66,7 +67,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, std::string errorMessage; std::unique_ptr input; - if (auto inputAlignment = translationRequested->getInputAlignment()) + if (auto inputAlignment = translationsRequested[0]->getInputAlignment()) input = openInputFile(inputFilename, *inputAlignment, &errorMessage); else input = openInputFile(inputFilename, &errorMessage); @@ -84,23 +85,54 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, // Processes the memory buffer with a new MLIRContext. auto processBuffer = [&](std::unique_ptr ownedBuffer, raw_ostream &os) { - MLIRContext context; - context.allowUnregisteredDialects(allowUnregisteredDialects); - context.printOpOnDiagnostic(!verifyDiagnostics); - auto sourceMgr = std::make_shared(); - sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); - - if (!verifyDiagnostics) { - SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context); - return (*translationRequested)(sourceMgr, os, &context); + // Temporary buffers for chained translation processing. + std::string dataIn; + std::string dataOut; + LogicalResult result = LogicalResult::success(); + + for (size_t i = 0, e = translationsRequested.size(); i < e; ++i) { + llvm::raw_ostream *stream; + llvm::raw_string_ostream dataStream(dataOut); + + if (i == e - 1) { + // Output last translation to output. + stream = &os; + } else { + // Output translation to temporary data buffer. + stream = &dataStream; + } + + const Translation *translationRequested = translationsRequested[i]; + MLIRContext context; + context.allowUnregisteredDialects(allowUnregisteredDialects); + context.printOpOnDiagnostic(!verifyDiagnostics); + auto sourceMgr = std::make_shared(); + sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); + + if (verifyDiagnostics) { + // In the diagnostic verification flow, we ignore whether the + // translation failed (in most cases, it is expected to fail). + // Instead, we check if the diagnostics were produced as expected. + SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, + &context); + (void)(*translationRequested)(sourceMgr, os, &context); + result = sourceMgrHandler.verify(); + } else { + SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context); + result = (*translationRequested)(sourceMgr, *stream, &context); + } + if (failed(result)) + return result; + + if (i < e - 1) { + // If there are further translations, create a new buffer with the + // output data. + dataIn = dataOut; + dataOut.clear(); + ownedBuffer = llvm::MemoryBuffer::getMemBuffer(dataIn); + } } - - // In the diagnostic verification flow, we ignore whether the translation - // failed (in most cases, it is expected to fail). Instead, we check if the - // diagnostics were produced as expected. - SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context); - (void)(*translationRequested)(sourceMgr, os, &context); - return sourceMgrHandler.verify(); + return result; }; if (failed(splitAndProcessBuffer(std::move(input), processBuffer, diff --git a/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir b/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir new file mode 100644 index 0000000..427b926 --- /dev/null +++ b/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-translate -no-implicit-module -split-input-file -serialize-spirv -deserialize-spirv %s | FileCheck %s + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @array_stride(%arg0 : !spirv.ptr, stride=128>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" { + // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr, stride=128>, StorageBuffer>, i32, i32 + %2 = spirv.AccessChain %arg0[%arg1, %arg2] : !spirv.ptr, stride=128>, StorageBuffer>, i32, i32 + spirv.Return + } +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK: spirv.GlobalVariable {{@.*}} : !spirv.ptr, StorageBuffer> + spirv.GlobalVariable @var0 : !spirv.ptr, StorageBuffer> + // CHECK: spirv.GlobalVariable {{@.*}} : !spirv.ptr>, Input> + spirv.GlobalVariable @var1 : !spirv.ptr>, Input> +} -- 2.7.4