Change MLIR translation functions signature
authorLei Zhang <antiagainst@google.com>
Tue, 17 Sep 2019 19:16:05 +0000 (12:16 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 17 Sep 2019 19:16:45 +0000 (12:16 -0700)
This CL changes translation functions to take MemoryBuffer
as input and raw_ostream as output. It is generally better to
avoid handling files directly in a library (unless the library
is specifically for file manipulation) and we can unify all
file handling to the mlir-translate binary itself.

PiperOrigin-RevId: 269625911

mlir/include/mlir/Support/TranslateClParser.h
mlir/include/mlir/Translation.h
mlir/lib/Dialect/SPIRV/Serialization/ConvertFromBinary.cpp
mlir/lib/Dialect/SPIRV/Serialization/ConvertToBinary.cpp
mlir/lib/Support/TranslateClParser.cpp
mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
mlir/tools/mlir-translate/mlir-translate.cpp

index d81dd83..ccd4fb9 100644 (file)
@@ -23,6 +23,7 @@
 #define MLIR_SUPPORT_TRANSLATE_CL_PARSER_H_
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Translation.h"
 #include "llvm/Support/CommandLine.h"
 #include <functional>
 
@@ -31,10 +32,6 @@ namespace mlir {
 struct LogicalResult;
 class MLIRContext;
 
-/// Common interface for source-to-source translation functions.
-using TranslateFunction = std::function<LogicalResult(
-    StringRef inputFilename, StringRef outputFilename, MLIRContext *)>;
-
 /// Custom parser for TranslateFunction.
 /// Wraps TranslateToMLIRFunctions and TranslateFromMLIRFunctions into
 /// TranslateFunctions before registering them as options.
index b0cb930..0223245 100644 (file)
 #define MLIR_TRANSLATION_H
 
 #include "llvm/ADT/StringMap.h"
-#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+
+namespace llvm {
+class MemoryBuffer;
+class StringRef;
+} // namespace llvm
 
 namespace mlir {
 struct LogicalResult;
@@ -30,15 +36,24 @@ class MLIRContext;
 class ModuleOp;
 class OwningModuleRef;
 
-/// Interface of the function that translates a file to MLIR.  The
-/// implementation should create a new MLIR ModuleOp in the given context and
-/// return a pointer to it, or a nullptr in case of any error.
-using TranslateToMLIRFunction =
-    std::function<OwningModuleRef(llvm::StringRef, MLIRContext *)>;
+/// Interface of the function that translates a source file held by the given
+/// MemoryBuffer to MLIR. The implementation should create a new MLIR ModuleOp
+/// in the given context and return a pointer to it, or a nullptr in case of any
+/// error.
+using TranslateToMLIRFunction = std::function<OwningModuleRef(
+    std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *)>;
+
 /// Interface of the function that translates MLIR to a different format and
-/// outputs the result to a file. It is allowed to modify the module.
+/// outputs the result to a stream. It is allowed to modify the module.
 using TranslateFromMLIRFunction =
-    std::function<LogicalResult(ModuleOp, llvm::StringRef)>;
+    std::function<LogicalResult(ModuleOp, llvm::raw_ostream &output)>;
+
+/// Interface of the function that performs file-to-file translation involving
+/// MLIR. The input file is held in the given MemoryBuffer; the output file
+/// should be written to the given raw_ostream.
+using TranslateFunction =
+    std::function<LogicalResult(std::unique_ptr<llvm::MemoryBuffer> input,
+                                llvm::raw_ostream &output, MLIRContext *)>;
 
 /// Use Translate[To|From]MLIRRegistration as a global initialiser that
 /// registers a function and associates it with name. This requires that a
index 66b178b..37d15f4 100644 (file)
@@ -35,20 +35,13 @@ using namespace mlir;
 
 // Deserializes the SPIR-V binary module stored in the file named as
 // `inputFilename` and returns a module containing the SPIR-V module.
-OwningModuleRef deserializeModule(llvm::StringRef inputFilename,
+OwningModuleRef deserializeModule(std::unique_ptr<llvm::MemoryBuffer> input,
                                   MLIRContext *context) {
   Builder builder(context);
 
-  std::string errorMessage;
-  auto file = openInputFile(inputFilename, &errorMessage);
-  if (!file) {
-    emitError(UnknownLoc::get(context), errorMessage);
-    return {};
-  }
-
   // Make sure the input stream can be treated as a stream of SPIR-V words
-  auto start = file->getBufferStart();
-  auto size = file->getBufferSize();
+  auto start = input->getBufferStart();
+  auto size = input->getBufferSize();
   if (size % sizeof(uint32_t) != 0) {
     emitError(UnknownLoc::get(context))
         << "SPIR-V binary module must contain integral number of 32-bit words";
@@ -62,15 +55,15 @@ OwningModuleRef deserializeModule(llvm::StringRef inputFilename,
   if (!spirvModule)
     return {};
 
-  OwningModuleRef module(ModuleOp::create(
-      FileLineColLoc::get(inputFilename, /*line=*/0, /*column=*/0, context)));
+  OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
+      input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
   module->getBody()->push_front(spirvModule->getOperation());
 
   return module;
 }
 
-static TranslateToMLIRRegistration
-    registration("deserialize-spirv",
-                 [](StringRef inputFilename, MLIRContext *context) {
-                   return deserializeModule(inputFilename, context);
-                 });
+static TranslateToMLIRRegistration registration(
+    "deserialize-spirv",
+    [](std::unique_ptr<llvm::MemoryBuffer> input, MLIRContext *context) {
+      return deserializeModule(std::move(input), context);
+    });
index 8267e6b..caf8323 100644 (file)
@@ -24,7 +24,6 @@
 #include "mlir/Dialect/SPIRV/Serialization.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
-#include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Translation.h"
 #include "llvm/ADT/StringRef.h"
@@ -32,7 +31,7 @@
 
 using namespace mlir;
 
-LogicalResult serializeModule(ModuleOp module, StringRef outputFilename) {
+LogicalResult serializeModule(ModuleOp module, llvm::raw_ostream &output) {
   if (!module)
     return failure();
 
@@ -51,19 +50,14 @@ LogicalResult serializeModule(ModuleOp module, StringRef outputFilename) {
   if (failed(result))
     return failure();
 
-  auto file = openOutputFile(outputFilename);
-  if (!file)
-    return failure();
-
-  file->os().write(reinterpret_cast<char *>(binary.data()),
-                   binary.size() * sizeof(uint32_t));
-  file->keep();
+  output.write(reinterpret_cast<char *>(binary.data()),
+               binary.size() * sizeof(uint32_t));
 
   return mlir::success();
 }
 
 static TranslateFromMLIRRegistration
     registration("serialize-spirv",
-                 [](ModuleOp module, StringRef outputFilename) {
-                   return serializeModule(module, outputFilename);
+                 [](ModuleOp module, llvm::raw_ostream &output) {
+                   return serializeModule(module, output);
                  });
index 8a7367f..b9535a0 100644 (file)
@@ -25,7 +25,6 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Parser.h"
-#include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Translation.h"
 #include "llvm/Support/CommandLine.h"
@@ -38,15 +37,10 @@ using namespace mlir;
 // Storage for the translation function wrappers that survive the parser.
 static llvm::SmallVector<TranslateFunction, 16> wrapperStorage;
 
-static LogicalResult printMLIROutput(ModuleOp module,
-                                     llvm::StringRef outputFilename) {
+static LogicalResult printMLIROutput(ModuleOp module, llvm::raw_ostream &os) {
   if (failed(verify(module)))
     return failure();
-  auto file = openOutputFile(outputFilename);
-  if (!file)
-    return failure();
-  module.print(file->os());
-  file->keep();
+  module.print(os);
   return success();
 }
 
@@ -60,14 +54,14 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
   wrapperStorage.reserve(toMLIRRegistry.size() + fromMLIRRegistry.size());
   for (const auto &kv : toMLIRRegistry) {
     TranslateToMLIRFunction function = kv.second;
-    TranslateFunction wrapper = [function](StringRef inputFilename,
-                                           StringRef outputFilename,
-                                           MLIRContext *context) {
-      OwningModuleRef module = function(inputFilename, context);
-      if (!module)
-        return failure();
-      return printMLIROutput(*module, outputFilename);
-    };
+    TranslateFunction wrapper =
+        [function](std::unique_ptr<llvm::MemoryBuffer> input,
+                   llvm::raw_ostream &output, MLIRContext *context) {
+          OwningModuleRef module = function(std::move(input), context);
+          if (!module)
+            return failure();
+          return printMLIROutput(*module, output);
+        };
     wrapperStorage.emplace_back(std::move(wrapper));
 
     addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
@@ -75,17 +69,18 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
 
   for (const auto &kv : fromMLIRRegistry) {
     TranslateFromMLIRFunction function = kv.second;
-    TranslateFunction wrapper = [function](StringRef inputFilename,
-                                           StringRef outputFilename,
-                                           MLIRContext *context) {
-      llvm::SourceMgr sourceMgr;
-      SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
-      auto module =
-          OwningModuleRef(parseSourceFile(inputFilename, sourceMgr, context));
-      if (!module)
-        return failure();
-      return function(module.get(), outputFilename);
-    };
+    TranslateFunction wrapper =
+        [function](std::unique_ptr<llvm::MemoryBuffer> input,
+                   llvm::raw_ostream &output, MLIRContext *context) {
+          llvm::SourceMgr sourceMgr;
+          sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
+          SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
+
+          auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
+          if (!module)
+            return failure();
+          return function(module.get(), output);
+        };
     wrapperStorage.emplace_back(std::move(wrapper));
 
     addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
index 0ba1581..b3bea0f 100644 (file)
@@ -21,7 +21,6 @@
 
 #include "mlir/Target/LLVMIR.h"
 
-#include "mlir/Support/FileUtilities.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Translation.h"
 
@@ -36,19 +35,11 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(ModuleOp m) {
 }
 
 static TranslateFromMLIRRegistration registration(
-    "mlir-to-llvmir", [](ModuleOp module, llvm::StringRef outputFilename) {
-      if (!module)
-        return failure();
-
+    "mlir-to-llvmir", [](ModuleOp module, llvm::raw_ostream &output) {
       auto llvmModule = LLVM::ModuleTranslation::translateModule<>(module);
       if (!llvmModule)
         return failure();
 
-      auto file = openOutputFile(outputFilename);
-      if (!file)
-        return failure();
-
-      llvmModule->print(file->os(), nullptr);
-      file->keep();
+      llvmModule->print(output, nullptr);
       return success();
     });
index 32fa167..2b7ff63 100644 (file)
@@ -26,7 +26,6 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
-#include "mlir/Support/FileUtilities.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Translation.h"
 
@@ -92,19 +91,11 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(ModuleOp m) {
 
 static TranslateFromMLIRRegistration
     registration("mlir-to-nvvmir",
-                 [](ModuleOp module, llvm::StringRef outputFilename) {
-                   if (!module)
-                     return failure();
-
+                 [](ModuleOp module, llvm::raw_ostream &output) {
                    auto llvmModule = mlir::translateModuleToNVVMIR(module);
                    if (!llvmModule)
                      return failure();
 
-                   auto file = openOutputFile(outputFilename);
-                   if (!file)
-                     return failure();
-
-                   llvmModule->print(file->os(), nullptr);
-                   file->keep();
+                   llvmModule->print(output, nullptr);
                    return success();
                  });
index 282eae8..8b32881 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/TranslateClParser.h"
 #include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/ToolOutputFile.h"
 
 using namespace mlir;
 
@@ -44,7 +47,23 @@ int main(int argc, char **argv) {
                            llvm::cl::Required);
   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR translation driver\n");
 
+  std::string errorMessage;
+  auto input = openInputFile(inputFilename, &errorMessage);
+  if (!input) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  auto output = openOutputFile(outputFilename, &errorMessage);
+  if (!output) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
   MLIRContext context;
-  return failed(
-      (*translationRequested)(inputFilename, outputFilename, &context));
+  if (failed((*translationRequested)(std::move(input), output->os(), &context)))
+    return 1;
+
+  output->keep();
+  return 0;
 }