[mlir][Translation] Allow specifying an expected input alignment for "ToMLIR" transla...
authorRiver Riddle <riddleriver@gmail.com>
Tue, 15 Nov 2022 02:16:28 +0000 (18:16 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Wed, 16 Nov 2022 01:22:41 +0000 (17:22 -0800)
This allows for ensuring that alignment requirements on translation
inputs are satisfied.

Differential Revision: https://reviews.llvm.org/D137999

mlir/include/mlir/Tools/mlir-translate/Translation.h
mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
mlir/lib/Tools/mlir-translate/Translation.cpp

index d3cd817..80c4e37 100644 (file)
@@ -47,9 +47,44 @@ using TranslateFromMLIRFunction =
 using TranslateFunction = std::function<LogicalResult(
     llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, MLIRContext *)>;
 
+/// This class contains all of the components necessary for performing a
+/// translation.
+class Translation {
+public:
+  Translation() = default;
+  Translation(TranslateFunction function, StringRef description,
+              Optional<llvm::Align> inputAlignment)
+      : function(std::move(function)), description(description),
+        inputAlignment(inputAlignment) {}
+
+  /// Return the description of this translation.
+  StringRef getDescription() const { return description; }
+
+  /// Return the optional alignment desired for the input of the translation.
+  Optional<llvm::Align> getInputAlignment() const { return inputAlignment; }
+
+  /// Invoke the translation function with the given input and output streams.
+  LogicalResult operator()(llvm::SourceMgr &sourceMgr,
+                           llvm::raw_ostream &output,
+                           MLIRContext *context) const {
+    return function(sourceMgr, output, context);
+  }
+
+private:
+  /// The underlying translation function.
+  TranslateFunction function;
+
+  /// The description of the translation.
+  StringRef description;
+
+  /// An optional alignment desired for the input of the translation.
+  Optional<llvm::Align> inputAlignment;
+};
+
 /// Use Translate[ToMLIR|FromMLIR]Registration as an initializer that
 /// registers a function and associates it with name. This requires that a
-/// translation has not been registered to a given name.
+/// translation has not been registered to a given name. `inputAlign` is an
+/// optional expected alignment for the input data.
 ///
 /// Usage:
 ///
@@ -62,10 +97,14 @@ using TranslateFunction = std::function<LogicalResult(
 ///
 /// \{
 struct TranslateToMLIRRegistration {
-  TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description,
-                              const TranslateSourceMgrToMLIRFunction &function);
-  TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description,
-                              const TranslateStringRefToMLIRFunction &function);
+  TranslateToMLIRRegistration(
+      llvm::StringRef name, llvm::StringRef description,
+      const TranslateSourceMgrToMLIRFunction &function,
+      Optional<llvm::Align> inputAlignment = llvm::None);
+  TranslateToMLIRRegistration(
+      llvm::StringRef name, llvm::StringRef description,
+      const TranslateStringRefToMLIRFunction &function,
+      Optional<llvm::Align> inputAlignment = llvm::None);
 };
 
 struct TranslateFromMLIRRegistration {
@@ -99,7 +138,7 @@ struct TranslateRegistration {
 /// \}
 
 /// A command line parser for translation functions.
-struct TranslationParser : public llvm::cl::parser<const TranslateFunction *> {
+struct TranslationParser : public llvm::cl::parser<const Translation *> {
   TranslationParser(llvm::cl::Option &opt);
 
   void printOptionInfo(const llvm::cl::Option &o,
index ef2545b..51b21f2 100644 (file)
@@ -56,7 +56,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
   llvm::InitLLVM y(argc, argv);
 
   // Add flags for all the registered translations.
-  llvm::cl::opt<const TranslateFunction *, false, TranslationParser>
+  llvm::cl::opt<const Translation *, false, TranslationParser>
       translationRequested("", llvm::cl::desc("Translation to perform"),
                            llvm::cl::Required);
   registerAsmPrinterCLOptions();
@@ -65,7 +65,11 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
   llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
 
   std::string errorMessage;
-  auto input = openInputFile(inputFilename, &errorMessage);
+  std::unique_ptr<llvm::MemoryBuffer> input;
+  if (auto inputAlignment = translationRequested->getInputAlignment())
+    input = openInputFile(inputFilename, *inputAlignment, &errorMessage);
+  else
+    input = openInputFile(inputFilename, &errorMessage);
   if (!input) {
     llvm::errs() << errorMessage << "\n";
     return failure();
index ab86cd0..548e3f9 100644 (file)
@@ -40,34 +40,30 @@ void mlir::registerTranslationCLOptions() { *clOptions; }
 // Translation Registry
 //===----------------------------------------------------------------------===//
 
-struct TranslationBundle {
-  TranslateFunction translateFunction;
-  StringRef translateDescription;
-};
-
-/// Get the mutable static map between registered file-to-file MLIR translations
-/// and TranslateFunctions with its description that perform those translations.
-static llvm::StringMap<TranslationBundle> &getTranslationRegistry() {
-  static llvm::StringMap<TranslationBundle> translationBundle;
+/// Get the mutable static map between registered file-to-file MLIR
+/// translations.
+static llvm::StringMap<Translation> &getTranslationRegistry() {
+  static llvm::StringMap<Translation> translationBundle;
   return translationBundle;
 }
 
 /// Register the given translation.
 static void registerTranslation(StringRef name, StringRef description,
+                                Optional<llvm::Align> inputAlignment,
                                 const TranslateFunction &function) {
-  auto &translationRegistry = getTranslationRegistry();
-  if (translationRegistry.find(name) != translationRegistry.end())
+  auto &registry = getTranslationRegistry();
+  if (registry.count(name))
     llvm::report_fatal_error(
         "Attempting to overwrite an existing <file-to-file> function");
   assert(function &&
          "Attempting to register an empty translate <file-to-file> function");
-  translationRegistry[name].translateFunction = function;
-  translationRegistry[name].translateDescription = description;
+  registry[name] = Translation(function, description, inputAlignment);
 }
 
 TranslateRegistration::TranslateRegistration(
     StringRef name, StringRef description, const TranslateFunction &function) {
-  registerTranslation(name, description, function);
+  registerTranslation(name, description, /*inputAlignment=*/llvm::None,
+                      function);
 }
 
 //===----------------------------------------------------------------------===//
@@ -77,7 +73,7 @@ TranslateRegistration::TranslateRegistration(
 // Puts `function` into the to-MLIR translation registry unless there is already
 // a function registered for the same name.
 static void registerTranslateToMLIRFunction(
-    StringRef name, StringRef description,
+    StringRef name, StringRef description, Optional<llvm::Align> inputAlignment,
     const TranslateSourceMgrToMLIRFunction &function) {
   auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
                               MLIRContext *context) {
@@ -87,21 +83,23 @@ static void registerTranslateToMLIRFunction(
     op.get()->print(output);
     return success();
   };
-  registerTranslation(name, description, wrappedFn);
+  registerTranslation(name, description, inputAlignment, wrappedFn);
 }
 
 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
     StringRef name, StringRef description,
-    const TranslateSourceMgrToMLIRFunction &function) {
-  registerTranslateToMLIRFunction(name, description, function);
+    const TranslateSourceMgrToMLIRFunction &function,
+    Optional<llvm::Align> inputAlignment) {
+  registerTranslateToMLIRFunction(name, description, inputAlignment, function);
 }
 /// Wraps `function` with a lambda that extracts a StringRef from a source
 /// manager and registers the wrapper lambda as a to-MLIR conversion.
 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
     StringRef name, StringRef description,
-    const TranslateStringRefToMLIRFunction &function) {
+    const TranslateStringRefToMLIRFunction &function,
+    Optional<llvm::Align> inputAlignment) {
   registerTranslateToMLIRFunction(
-      name, description,
+      name, description, inputAlignment,
       [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
         const llvm::MemoryBuffer *buffer =
             sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
@@ -117,9 +115,8 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
     StringRef name, StringRef description,
     const TranslateFromMLIRFunction &function,
     const std::function<void(DialectRegistry &)> &dialectRegistration) {
-
   registerTranslation(
-      name, description,
+      name, description, /*inputAlignment=*/llvm::None,
       [function, dialectRegistration](llvm::SourceMgr &sourceMgr,
                                       raw_ostream &output,
                                       MLIRContext *context) {
@@ -141,11 +138,9 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
 //===----------------------------------------------------------------------===//
 
 TranslationParser::TranslationParser(llvm::cl::Option &opt)
-    : llvm::cl::parser<const TranslateFunction *>(opt) {
-  for (const auto &kv : getTranslationRegistry()) {
-    addLiteralOption(kv.first(), &kv.second.translateFunction,
-                     kv.second.translateDescription);
-  }
+    : llvm::cl::parser<const Translation *>(opt) {
+  for (const auto &kv : getTranslationRegistry())
+    addLiteralOption(kv.first(), &kv.second, kv.second.getDescription());
 }
 
 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
@@ -156,5 +151,5 @@ void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
                           const TranslationParser::OptionInfo *rhs) {
                          return lhs->Name.compare(rhs->Name);
                        });
-  llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
+  llvm::cl::parser<const Translation *>::printOptionInfo(o, globalWidth);
 }