Add Instance Specific Pass Options.
authorMLIR Team <no-reply@google.com>
Wed, 9 Oct 2019 01:23:13 +0000 (18:23 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 9 Oct 2019 01:23:43 +0000 (18:23 -0700)
This allows individual passes to define options structs and for these options to be parsed per instance of the pass while building the pass pipeline from the command line provided textual specification.

The user can specify these per-instance pipeline options like so:
```
struct MyPassOptions : public PassOptions<MyPassOptions> {
  Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
  List<int> exampleListOption{*this, "list-flag-name", llvm::cl::desc("...")};
};

static PassRegistration<MyPass, MyPassOptions> pass("my-pass", "description");
```

PiperOrigin-RevId: 273650140

mlir/g3doc/WritingAPass.md
mlir/include/mlir/Pass/PassRegistry.h
mlir/lib/Pass/PassRegistry.cpp
mlir/lib/Support/MlirOptMain.cpp
mlir/test/Pass/pipeline-options-parsing.mlir [new file with mode: 0644]
mlir/test/Transforms/parametric-tiling.mlir
mlir/test/lib/Pass/TestPassManager.cpp
mlir/test/lib/Transforms/TestLoopParametricTiling.cpp

index ae371d7..11637df 100644 (file)
@@ -362,11 +362,11 @@ void pipelineBuilder(OpPassManager &pm) {
 }
 
 // Register an existing pipeline builder function.
-static PassPipelineRegistration pipeline(
+static PassPipelineRegistration<> pipeline(
   "command-line-arg", "description", pipelineBuilder);
 
 // Register an inline pipeline builder.
-static PassPipelineRegistration pipeline(
+static PassPipelineRegistration<> pipeline(
   "command-line-arg", "description", [](OpPassManager &pm) {
     pm.addPass(std::make_unique<MyPass>());
     pm.addPass(std::make_unique<MyOtherPass>());
@@ -377,7 +377,7 @@ Pipeline registration also allows for simplified registration of
 specifializations for existing passes:
 
 ```c++
-static PassPipelineRegistration foo10(
+static PassPipelineRegistration<> foo10(
     "foo-10", "Foo Pass 10", [] { return std::make_unique<FooPass>(10); } );
 ```
 
@@ -401,15 +401,19 @@ pipeline description. The syntax for this specification is as follows:
 
 ```ebnf
 pipeline          ::= op-name `(` pipeline-element (`,` pipeline-element)* `)`
-pipeline-element  ::= pipeline | pass-name | pass-pipeline-name
+pipeline-element  ::= pipeline | (pass-name | pass-pipeline-name) options?
+options           ::= '{' (key ('=' value)?)+ '}'
 ```
 
 *   `op-name`
-    *   This corresponds to the mneumonic name of an operation to run passes on,
+    *   This corresponds to the mnemonic name of an operation to run passes on,
         e.g. `func` or `module`.
 *   `pass-name` | `pass-pipeline-name`
     *   This corresponds to the command-line argument of a registered pass or
         pass pipeline, e.g. `cse` or `canonicalize`.
+*   `options`
+    *   Options are pass specific key value pairs that are handled as described
+        in the instance specific pass options section.
 
 For example, the following pipeline:
 
@@ -423,6 +427,35 @@ Can also be specified as (via the `-pass-pipeline` flag):
 $ mlir-opt foo.mlir -pass-pipeline='func(cse, canonicalize), lower-to-llvm'
 ```
 
+### Instance Specific Pass Options
+
+Options may be specified for a parametric pass. Individual options are defined
+using `llvm::cl::opt` flag definition rules. These options will then be parsed
+at pass construction time independently for each instance of the pass. The
+`PassRegistration` and `PassPipelineRegistration` templates take an additional
+optional template parameter that is the Option struct definition to be used for
+that pass. To use pass specific options, create a class that inherits from
+`mlir::PassOptions` and then add a new constructor that takes `const
+MyPassOptions&` and constructs the pass. When using `PassPipelineRegistration`,
+the constructor now takes a function with the signature `void (OpPassManager
+&pm, const MyPassOptions&)` which should construct the passes from the options
+and pass them to the pm. The user code will look like the following:
+
+```c++
+class MyPass ... {
+public:
+  MyPass(const MyPassOptions& options) ...
+};
+
+struct MyPassOptions : public PassOptions<MyPassOptions> {
+  // These just forward onto llvm::cl::list and llvm::cl::opt respectively.
+  Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
+  List<int> exampleListOption{*this, "list-flag-name", llvm::cl::desc("...")};
+};
+
+static PassRegistration<MyPass, MyPassOptions> pass("my-pass", "description");
+```
+
 ## Pass Instrumentation
 
 MLIR provides a customizable framework to instrument pass execution and analysis
index 6379b3d..356b13e 100644 (file)
@@ -24,6 +24,7 @@
 #define MLIR_PASS_PASSREGISTRY_H_
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include <memory>
 
 namespace mlir {
-struct LogicalResult;
 class OpPassManager;
 class Pass;
 
-/// A registry function that adds passes to the given pass manager.
-using PassRegistryFunction = std::function<void(OpPassManager &)>;
-
-using PassAllocatorFunction = std::function<std::unique_ptr<Pass>()>;
+/// A registry function that adds passes to the given pass manager. This should
+/// also parse options and return success() if parsing succeeded.
+using PassRegistryFunction =
+    std::function<LogicalResult(OpPassManager &, StringRef options)>;
 
 /// A special type used by transformation passes to provide an address that can
 /// act as a unique identifier during pass registration.
@@ -53,11 +53,13 @@ using PassID = ClassID;
 /// to invoke via mlir-opt, description, pass pipeline builder).
 class PassRegistryEntry {
 public:
-  /// Adds this pass registry entry to the given pass manager.
-  void addToPipeline(OpPassManager &pm) const {
+  /// Adds this pass registry entry to the given pass manager. `options` is
+  /// an opaque string that will be parsed by the builder. The success of
+  /// parsing will be returned.
+  LogicalResult addToPipeline(OpPassManager &pm, StringRef options) const {
     assert(builder &&
            "cannot call addToPipeline on PassRegistryEntry without builder");
-    builder(pm);
+    return builder(pm, options);
   }
 
   /// Returns the command line option that may be passed to 'mlir-opt' that will
@@ -97,7 +99,8 @@ public:
   /// PassInfo constructor should not be invoked directly, instead use
   /// PassRegistration or registerPass.
   PassInfo(StringRef arg, StringRef description, const PassID *passID,
-           PassAllocatorFunction allocator);
+           PassRegistryFunction allocator)
+      : PassRegistryEntry(arg, description, allocator) {}
 };
 
 //===----------------------------------------------------------------------===//
@@ -112,7 +115,89 @@ void registerPassPipeline(StringRef arg, StringRef description,
 /// Register a specific dialect pass allocator function with the system,
 /// typically used through the PassRegistration template.
 void registerPass(StringRef arg, StringRef description, const PassID *passID,
-                  const PassAllocatorFunction &function);
+                  const PassRegistryFunction &function);
+
+namespace detail {
+/// Base class for PassOptions<T> that holds all of the non-CRTP features.
+class PassOptionsBase : protected llvm::cl::SubCommand {
+public:
+  /// This class represents a specific pass option, with a provided data type.
+  template <typename DataType> struct Option : public llvm::cl::opt<DataType> {
+    template <typename... Args>
+    Option(PassOptionsBase &parent, StringRef arg, Args &&... args)
+        : llvm::cl::opt<DataType>(arg, llvm::cl::sub(parent),
+                                  std::forward<Args>(args)...) {
+      assert(!this->isPositional() && !this->isSink() &&
+             "sink and positional options are not supported");
+    }
+  };
+
+  /// This class represents a specific pass option that contains a list of
+  /// values of the provided data type.
+  template <typename DataType> struct List : public llvm::cl::list<DataType> {
+    template <typename... Args>
+    List(PassOptionsBase &parent, StringRef arg, Args &&... args)
+        : llvm::cl::list<DataType>(arg, llvm::cl::sub(parent),
+                                   std::forward<Args>(args)...) {
+      assert(!this->isPositional() && !this->isSink() &&
+             "sink and positional options are not supported");
+    }
+  };
+
+  /// Parse options out as key=value pairs that can then be handed off to the
+  /// `llvm::cl` command line passing infrastructure. Everything is space
+  /// separated.
+  LogicalResult parseFromString(StringRef options);
+};
+} // end namespace detail
+
+/// Subclasses of PassOptions provide a set of options that can be used to
+/// initialize a pass instance. See PassRegistration for usage details.
+///
+/// Usage:
+///
+/// struct MyPassOptions : PassOptions<MyPassOptions> {
+///   List<int> someListFlag{
+///        *this, "flag-name", llvm::cl::MiscFlags::CommaSeparated,
+///        llvm::cl::desc("...")};
+/// };
+template <typename T> class PassOptions : public detail::PassOptionsBase {
+public:
+  /// Factory that parses the provided options and returns a unique_ptr to the
+  /// struct.
+  static std::unique_ptr<T> createFromString(StringRef options) {
+    auto result = std::make_unique<T>();
+    if (failed(result->parseFromString(options)))
+      return nullptr;
+    return result;
+  }
+};
+
+/// A default empty option struct to be used for passes that do not need to take
+/// any options.
+struct EmptyPassOptions : public PassOptions<EmptyPassOptions> {};
+
+namespace detail {
+
+// Calls `pm.addPass(std::move(pass))` to avoid including the PassManager
+// header. Only used in `makePassRegistryFunction`.
+void addPassToPassManager(OpPassManager &pm, std::unique_ptr<Pass> pass);
+
+// Helper function which constructs a PassRegistryFunction that parses options
+// into a struct of type `Options` and then calls constructor(options) to
+// build the pass.
+template <typename Options, typename PassConstructor>
+PassRegistryFunction makePassRegistryFunction(PassConstructor constructor) {
+  return [=](OpPassManager &pm, StringRef optionsStr) {
+    Options options;
+    if (failed(options.parseFromString(optionsStr)))
+      return failure();
+    addPassToPassManager(pm, constructor(options));
+    return success();
+  };
+}
+
+} // end namespace detail
 
 /// PassRegistration provides a global initializer that registers a Pass
 /// allocation routine for a concrete pass instance.  The third argument is
@@ -122,18 +207,47 @@ void registerPass(StringRef arg, StringRef description, const PassID *passID,
 /// Usage:
 ///
 ///   // At namespace scope.
-///   static PassRegistration<MyPass> Unused("unused", "Unused pass");
-template <typename ConcretePass> struct PassRegistration {
+///   static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
+///
+///   // Same, but also providing an Options struct.
+///   static PassRegistration<MyPass, MyPassOptions> reg("my-pass", "Docs...");
+template <typename ConcretePass, typename Options = EmptyPassOptions>
+struct PassRegistration {
+
   PassRegistration(StringRef arg, StringRef description,
-                   const PassAllocatorFunction &constructor) {
-    registerPass(arg, description, PassID::getID<ConcretePass>(), constructor);
+                   const std::function<std::unique_ptr<Pass>(const Options &)>
+                       &constructor) {
+    registerPass(arg, description, PassID::getID<ConcretePass>(),
+                 detail::makePassRegistryFunction<Options>(constructor));
   }
 
   PassRegistration(StringRef arg, StringRef description) {
-    PassAllocatorFunction constructor = [] {
-      return std::make_unique<ConcretePass>();
-    };
-    registerPass(arg, description, PassID::getID<ConcretePass>(), constructor);
+    registerPass(
+        arg, description, PassID::getID<ConcretePass>(),
+        detail::makePassRegistryFunction<Options>([](const Options &options) {
+          return std::make_unique<ConcretePass>(options);
+        }));
+  }
+};
+
+/// Convenience specialization of PassRegistration for EmptyPassOptions that
+/// does not pass an empty options struct to the pass constructor.
+template <typename ConcretePass>
+struct PassRegistration<ConcretePass, EmptyPassOptions> {
+  PassRegistration(StringRef arg, StringRef description,
+                   const std::function<std::unique_ptr<Pass>()> &constructor) {
+    registerPass(
+        arg, description, PassID::getID<ConcretePass>(),
+        detail::makePassRegistryFunction<EmptyPassOptions>(
+            [=](const EmptyPassOptions &options) { return constructor(); }));
+  }
+
+  PassRegistration(StringRef arg, StringRef description) {
+    registerPass(arg, description, PassID::getID<ConcretePass>(),
+                 detail::makePassRegistryFunction<EmptyPassOptions>(
+                     [](const EmptyPassOptions &options) {
+                       return std::make_unique<ConcretePass>();
+                     }));
   }
 };
 
@@ -150,17 +264,34 @@ template <typename ConcretePass> struct PassRegistration {
 ///
 ///   static PassPipelineRegistration Unused("unused", "Unused pass",
 ///                                          pipelineBuilder);
-struct PassPipelineRegistration {
-  PassPipelineRegistration(StringRef arg, StringRef description,
-                           PassRegistryFunction builder) {
-    registerPassPipeline(arg, description, builder);
+template <typename Options = EmptyPassOptions> struct PassPipelineRegistration {
+  PassPipelineRegistration(
+      StringRef arg, StringRef description,
+      std::function<void(OpPassManager &, const Options &options)> builder) {
+    registerPassPipeline(arg, description,
+                         [builder](OpPassManager &pm, StringRef optionsStr) {
+                           Options options;
+                           if (failed(options.parseFromString(optionsStr)))
+                             return failure();
+                           builder(pm, options);
+                           return success();
+                         });
   }
+};
 
-  /// Constructor that accepts a pass allocator function instead of the standard
-  /// registry function. This is useful for registering specializations of
-  /// existing passes.
+/// Convenience specialization of PassPipelineRegistration for EmptyPassOptions
+/// that does not pass an empty options struct to the pass builder function.
+template <> struct PassPipelineRegistration<EmptyPassOptions> {
   PassPipelineRegistration(StringRef arg, StringRef description,
-                           PassAllocatorFunction allocator);
+                           std::function<void(OpPassManager &)> builder) {
+    registerPassPipeline(arg, description,
+                         [builder](OpPassManager &pm, StringRef optionsStr) {
+                           if (!optionsStr.empty())
+                             return failure();
+                           builder(pm);
+                           return success();
+                         });
+  }
 };
 
 /// This function parses the textual representation of a pass pipeline, and adds
@@ -198,7 +329,9 @@ public:
   bool contains(const PassRegistryEntry *entry) const;
 
   /// Adds the passes defined by this parser entry to the given pass manager.
-  void addToPipeline(OpPassManager &pm) const;
+  /// Returns failure() if the pass could not be properly constructed due
+  /// to options parsing.
+  LogicalResult addToPipeline(OpPassManager &pm) const;
 
 private:
   std::unique_ptr<detail::PassPipelineCLParserImpl> impl;
index 5d46370..397fef3 100644 (file)
@@ -34,24 +34,16 @@ static llvm::ManagedStatic<llvm::DenseMap<const PassID *, PassInfo>>
 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
     passPipelineRegistry;
 
-/// Utility to create a default registry function from a pass instance.
-static PassRegistryFunction
-buildDefaultRegistryFn(PassAllocatorFunction allocator) {
-  return [=](OpPassManager &pm) { pm.addPass(allocator()); };
+// Helper to avoid exposing OpPassManager.
+void mlir::detail::addPassToPassManager(OpPassManager &pm,
+                                        std::unique_ptr<Pass> pass) {
+  pm.addPass(std::move(pass));
 }
 
 //===----------------------------------------------------------------------===//
 // PassPipelineInfo
 //===----------------------------------------------------------------------===//
 
-/// Constructor that accepts a pass allocator function instead of the standard
-/// registry function. This is useful for registering specializations of
-/// existing passes.
-PassPipelineRegistration::PassPipelineRegistration(
-    StringRef arg, StringRef description, PassAllocatorFunction allocator) {
-  registerPassPipeline(arg, description, buildDefaultRegistryFn(allocator));
-}
-
 void mlir::registerPassPipeline(StringRef arg, StringRef description,
                                 const PassRegistryFunction &function) {
   PassPipelineInfo pipelineInfo(arg, description, function);
@@ -64,13 +56,9 @@ void mlir::registerPassPipeline(StringRef arg, StringRef description,
 // PassInfo
 //===----------------------------------------------------------------------===//
 
-PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID,
-                   PassAllocatorFunction allocator)
-    : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)) {}
-
 void mlir::registerPass(StringRef arg, StringRef description,
                         const PassID *passID,
-                        const PassAllocatorFunction &function) {
+                        const PassRegistryFunction &function) {
   PassInfo passInfo(arg, description, passID, function);
   bool inserted = passRegistry->try_emplace(passID, passInfo).second;
   assert(inserted && "Pass registered multiple times");
@@ -86,6 +74,52 @@ const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) {
 }
 
 //===----------------------------------------------------------------------===//
+// PassOptions
+//===----------------------------------------------------------------------===//
+
+LogicalResult PassOptionsBase::parseFromString(StringRef options) {
+  // TODO(parkers): Handle escaping strings.
+  // NOTE: `options` is modified in place to always refer to the unprocessed
+  // part of the string.
+  while (!options.empty()) {
+    size_t spacePos = options.find(' ');
+    StringRef arg = options;
+    if (spacePos != StringRef::npos) {
+      arg = options.substr(0, spacePos);
+      options = options.substr(spacePos + 1);
+    } else {
+      options = StringRef();
+    }
+    if (arg.empty())
+      continue;
+
+    // At this point, arg refers to everything that is non-space in options
+    // upto the next space, and options refers to the rest of the string after
+    // that point.
+
+    // Split the individual option on '=' to form key and value. If there is no
+    // '=', then value is `StringRef()`.
+    size_t equalPos = arg.find('=');
+    StringRef key = arg;
+    StringRef value;
+    if (equalPos != StringRef::npos) {
+      key = arg.substr(0, equalPos);
+      value = arg.substr(equalPos + 1);
+    }
+    auto it = OptionsMap.find(key);
+    if (it == OptionsMap.end()) {
+      llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
+
+      return failure();
+    }
+    if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
+      return failure();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // TextualPassPipeline Parser
 //===----------------------------------------------------------------------===//
 
@@ -98,7 +132,7 @@ public:
   LogicalResult initialize(StringRef text, raw_ostream &errorStream);
 
   /// Add the internal pipeline elements to the provided pass manager.
-  void addToPipeline(OpPassManager &pm) const;
+  LogicalResult addToPipeline(OpPassManager &pm) const;
 
 private:
   /// A functor used to emit errors found during pipeline handling. The first
@@ -117,6 +151,7 @@ private:
     PipelineElement(StringRef name) : name(name), registryEntry(nullptr) {}
 
     StringRef name;
+    StringRef options;
     const PassRegistryEntry *registryEntry;
     std::vector<PipelineElement> innerPipeline;
   };
@@ -137,8 +172,8 @@ private:
                                        ErrorHandlerT errorHandler);
 
   /// Add the given pipeline elements to the provided pass manager.
-  void addToPipeline(ArrayRef<PipelineElement> elements,
-                     OpPassManager &pm) const;
+  LogicalResult addToPipeline(ArrayRef<PipelineElement> elements,
+                              OpPassManager &pm) const;
 
   std::vector<PipelineElement> pipeline;
 };
@@ -167,8 +202,8 @@ LogicalResult TextualPipeline::initialize(StringRef text,
 }
 
 /// Add the internal pipeline elements to the provided pass manager.
-void TextualPipeline::addToPipeline(OpPassManager &pm) const {
-  addToPipeline(pipeline, pm);
+LogicalResult TextualPipeline::addToPipeline(OpPassManager &pm) const {
+  return addToPipeline(pipeline, pm);
 }
 
 /// Parse the given pipeline text into the internal pipeline vector. This
@@ -179,21 +214,36 @@ LogicalResult TextualPipeline::parsePipelineText(StringRef text,
   SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
   for (;;) {
     std::vector<PipelineElement> &pipeline = *pipelineStack.back();
-    size_t pos = text.find_first_of(",()");
+    size_t pos = text.find_first_of(",(){");
     pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
 
     // If we have a single terminating name, we're done.
     if (pos == text.npos)
       break;
 
-    char sep = text[pos];
-    text = text.substr(pos + 1);
+    text = text.substr(pos);
+    char sep = text[0];
 
-    // Just a name ending in a comma, continue.
-    if (sep == ',')
-      continue;
+    // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
+    if (sep == '{') {
+      text = text.substr(1);
+
+      // Skip over everything until the closing '}' and store as options.
+      size_t close = text.find('}');
+
+      // TODO(parkers): Handle skipping over quoted sub-strings.
+      if (close == StringRef::npos) {
+        return errorHandler(
+            /*rawLoc=*/text.data() - 1,
+            "missing closing '}' while processing pass options");
+      }
+      pipeline.back().options = text.substr(0, close);
+      text = text.substr(close + 1);
+
+      // Skip checking for '(' because nested pipelines cannot have options.
+    } else if (sep == '(') {
+      text = text.substr(1);
 
-    if (sep == '(') {
       // Push the inner pipeline onto the stack to continue processing.
       pipelineStack.push_back(&pipeline.back().innerPipeline);
       continue;
@@ -201,8 +251,7 @@ LogicalResult TextualPipeline::parsePipelineText(StringRef text,
 
     // When handling the close parenthesis, we greedily consume them to avoid
     // empty strings in the pipeline.
-    assert(sep == ')' && "Bogus separator!");
-    do {
+    while (text.consume_front(")")) {
       // If we try to pop the outer pipeline we have unbalanced parentheses.
       if (pipelineStack.size() == 1)
         return errorHandler(/*rawLoc=*/text.data() - 1,
@@ -210,7 +259,7 @@ LogicalResult TextualPipeline::parsePipelineText(StringRef text,
                             "parentheses while parsing pipeline");
 
       pipelineStack.pop_back();
-    } while (text.consume_front(")"));
+    }
 
     // Check if we've finished parsing.
     if (text.empty())
@@ -276,14 +325,17 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element,
 }
 
 /// Add the given pipeline elements to the provided pass manager.
-void TextualPipeline::addToPipeline(ArrayRef<PipelineElement> elements,
-                                    OpPassManager &pm) const {
+LogicalResult TextualPipeline::addToPipeline(ArrayRef<PipelineElement> elements,
+                                             OpPassManager &pm) const {
   for (auto &elt : elements) {
-    if (elt.registryEntry)
-      elt.registryEntry->addToPipeline(pm);
-    else
-      addToPipeline(elt.innerPipeline, pm.nest(elt.name));
+    if (elt.registryEntry) {
+      if (failed(elt.registryEntry->addToPipeline(pm, elt.options)))
+        return failure();
+    } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name)))) {
+      return failure();
+    }
   }
+  return success();
 }
 
 /// This function parses the textual representation of a pass pipeline, and adds
@@ -295,7 +347,8 @@ LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
   TextualPipeline pipelineParser;
   if (failed(pipelineParser.initialize(pipeline, errorStream)))
     return failure();
-  pipelineParser.addToPipeline(pm);
+  if (failed(pipelineParser.addToPipeline(pm)))
+    return failure();
   return success();
 }
 
@@ -315,6 +368,10 @@ struct PassArgData {
   /// or pass pipeline.
   const PassRegistryEntry *registryEntry;
 
+  /// This field is set when instance specific pass options have been provided
+  /// on the command line.
+  StringRef options;
+
   /// This field is used when the parsed option corresponds to an explicit
   /// pipeline.
   TextualPipeline pipeline;
@@ -396,7 +453,10 @@ bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
     return failed(value.pipeline.initialize(arg, llvm::errs()));
 
   // Otherwise, default to the base for handling.
-  return llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value);
+  if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
+    return true;
+  value.options = arg;
+  return false;
 }
 
 //===----------------------------------------------------------------------===//
@@ -437,11 +497,14 @@ bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
 }
 
 /// Adds the passes defined by this parser entry to the given pass manager.
-void PassPipelineCLParser::addToPipeline(OpPassManager &pm) const {
+LogicalResult PassPipelineCLParser::addToPipeline(OpPassManager &pm) const {
   for (auto &passIt : impl->passList) {
-    if (passIt.registryEntry)
-      passIt.registryEntry->addToPipeline(pm);
-    else
-      passIt.pipeline.addToPipeline(pm);
+    if (passIt.registryEntry) {
+      if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options)))
+        return failure();
+    } else if (failed(passIt.pipeline.addToPipeline(pm))) {
+      return failure();
+    }
   }
+  return success();
 }
index 9da2619..c256e97 100644 (file)
@@ -59,7 +59,8 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
   applyPassManagerCLOptions(pm);
 
   // Build the provided pipeline.
-  passPipeline.addToPipeline(pm);
+  if (failed(passPipeline.addToPipeline(pm)))
+    return failure();
 
   // Run the pipeline.
   if (failed(pm.run(*module)))
diff --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir
new file mode 100644 (file)
index 0000000..c45b2fc
--- /dev/null
@@ -0,0 +1,19 @@
+// RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass{)' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_1 %s
+// RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass{test-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s
+// RUN: not mlir-opt %s -pass-pipeline='module(test-options-pass{list=3}, test-module-pass{invalid-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s
+// RUN: not mlir-opt %s -pass-pipeline='test-options-pass{list=3 list=notaninteger}' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_4 %s
+// RUN: not mlir-opt %s -pass-pipeline='func(test-options-pass{list=1,2,3,4 list=5 string=value1 string=value2})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_5 %s
+// RUN: mlir-opt %s -pass-pipeline='func(test-options-pass{string-list=a list=1,2,3,4 string-list=b,c list=5 string-list=d string=some_value})' 2>&1 | FileCheck --check-prefix=CHECK_1 %s
+// RUN: mlir-opt %s -test-options-pass-pipeline='list=1 string-list=a,b' 2>&1 | FileCheck --check-prefix=CHECK_2 %s
+// RUN: mlir-opt %s -pass-pipeline='module(test-options-pass{list=3}, test-options-pass{list=1,2,3,4})' 2>&1 | FileCheck --check-prefix=CHECK_3 %s
+
+// CHECK_ERROR_1: missing closing '}' while processing pass options
+// CHECK_ERROR_2: no such option test-option
+// CHECK_ERROR_3: no such option invalid-option
+// CHECK_ERROR_4: 'notaninteger' value invalid for integer argument
+// CHECK_ERROR_5: string option: may only occur zero or one times
+
+// CHECK_1: test-options-pass{list=1,2,3,4,5 string-list=a,b,c,d string=some_value}
+// CHECK_2: test-options-pass{list=1 string-list=a,b}
+// CHECK_3: test-options-pass{list=3}
+// CHECK_3-NEXT: test-options-pass{list=1,2,3,4}
index 7ee323a..4265462 100644 (file)
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -test-extract-fixed-outer-loops -test-outer-loop-sizes=7 %s | FileCheck %s --check-prefixes=COMMON,TILE_7
-// RUN: mlir-opt -test-extract-fixed-outer-loops -test-outer-loop-sizes=7,4 %s | FileCheck %s --check-prefixes=COMMON,TILE_74
+// RUN: mlir-opt -test-extract-fixed-outer-loops='test-outer-loop-sizes=7' %s | FileCheck %s --check-prefixes=COMMON,TILE_7
+// RUN: mlir-opt -test-extract-fixed-outer-loops='test-outer-loop-sizes=7,4' %s | FileCheck %s --check-prefixes=COMMON,TILE_74
 
 // COMMON-LABEL: @rectangular
 func @rectangular(%arg0: memref<?x?xf32>) {
@@ -130,4 +130,4 @@ func @triangular(%arg0: memref<?x?xf32>) {
     }
   }
   return
-}
\ No newline at end of file
+}
index d4ef46a..9e91777 100644 (file)
@@ -28,6 +28,51 @@ struct TestModulePass : public ModulePass<TestModulePass> {
 struct TestFunctionPass : public FunctionPass<TestFunctionPass> {
   void runOnFunction() final {}
 };
+
+class TestOptionsPass : public FunctionPass<TestOptionsPass> {
+public:
+  struct Options : public PassOptions<Options> {
+    List<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
+                         llvm::cl::desc("Example list option")};
+    List<string> stringListOption{*this, "string-list",
+                                  llvm::cl::MiscFlags::CommaSeparated,
+                                  llvm::cl::desc("Example string list option")};
+    Option<std::string> stringOption{*this, "string",
+                                     llvm::cl::desc("Example string option")};
+  };
+  TestOptionsPass(const Options &options) {
+    listOption.assign(options.listOption.begin(), options.listOption.end());
+    stringOption = options.stringOption;
+    stringListOption.assign(options.stringListOption.begin(),
+                            options.stringListOption.end());
+    // Print out a debug representation of the pass in order to allow FileCheck
+    // testing of options parsing.
+    print(llvm::errs());
+    llvm::errs() << "\n";
+  }
+
+  void print(raw_ostream &os) {
+    os << "test-options-pass{";
+    if (!listOption.empty()) {
+      os << "list=";
+      // Not interleaveComma to avoid spaces between the elements.
+      interleave(listOption, os, ",");
+    }
+    if (!stringListOption.empty()) {
+      os << " string-list=";
+      interleave(stringListOption, os, ",");
+    }
+    if (!stringOption.empty())
+      os << " string=" << stringOption;
+    os << "}";
+  }
+
+  void runOnFunction() final {}
+
+  SmallVector<int64_t, 4> listOption;
+  SmallVector<std::string, 4> stringListOption;
+  std::string stringOption;
+};
 } // namespace
 
 static void testNestedPipeline(OpPassManager &pm) {
@@ -48,15 +93,26 @@ static void testNestedPipelineTextual(OpPassManager &pm) {
   (void)parsePassPipeline("test-pm-nested-pipeline", pm);
 }
 
+static PassRegistration<TestOptionsPass, TestOptionsPass::Options>
+    reg("test-options-pass", "Test options parsing capabilities");
+
 static PassRegistration<TestModulePass>
     unusedMP("test-module-pass", "Test a module pass in the pass manager");
 static PassRegistration<TestFunctionPass>
     unusedFP("test-function-pass", "Test a function pass in the pass manager");
 
-static PassPipelineRegistration
+static PassPipelineRegistration<>
     unused("test-pm-nested-pipeline",
            "Test a nested pipeline in the pass manager", testNestedPipeline);
-static PassPipelineRegistration
+static PassPipelineRegistration<>
     unusedTextual("test-textual-pm-nested-pipeline",
                   "Test a nested pipeline in the pass manager",
                   testNestedPipelineTextual);
+
+static PassPipelineRegistration<TestOptionsPass::Options>
+    registerOptionsPassPipeline(
+        "test-options-pass-pipeline",
+        "Parses options using pass pipeline registration",
+        [](OpPassManager &pm, const TestOptionsPass::Options &options) {
+          pm.addPass(std::make_unique<TestOptionsPass>(options));
+        });
index bce1e08..9a8e191 100644 (file)
 
 using namespace mlir;
 
-static llvm::cl::list<int> clOuterLoopSizes(
-    "test-outer-loop-sizes", llvm::cl::MiscFlags::CommaSeparated,
-    llvm::cl::desc(
-        "fixed number of iterations that the outer loops should have"));
-
 namespace {
+
 // Extracts fixed-range loops for top-level loop nests with ranges defined in
 // the pass constructor.  Assumes loops are permutable.
 class SimpleParametricLoopTilingPass
     : public FunctionPass<SimpleParametricLoopTilingPass> {
 public:
+  struct Options : public PassOptions<Options> {
+    List<int> clOuterLoopSizes{
+        *this, "test-outer-loop-sizes", llvm::cl::MiscFlags::CommaSeparated,
+        llvm::cl::desc(
+            "fixed number of iterations that the outer loops should have")};
+  };
+
   explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes)
       : sizes(outerLoopSizes.begin(), outerLoopSizes.end()) {}
+  explicit SimpleParametricLoopTilingPass(const Options &options) {
+    sizes.assign(options.clOuterLoopSizes.begin(),
+                 options.clOuterLoopSizes.end());
+  }
 
   void runOnFunction() override {
     FuncOp func = getFunction();
@@ -60,13 +67,8 @@ mlir::createSimpleParametricTilingPass(ArrayRef<int64_t> outerLoopSizes) {
   return std::make_unique<SimpleParametricLoopTilingPass>(outerLoopSizes);
 }
 
-static PassRegistration<SimpleParametricLoopTilingPass>
+static PassRegistration<SimpleParametricLoopTilingPass,
+                        SimpleParametricLoopTilingPass::Options>
     reg("test-extract-fixed-outer-loops",
         "test application of parametric tiling to the outer loops so that the "
-        "ranges of outer loops become static",
-        [] {
-          auto pass = std::make_unique<SimpleParametricLoopTilingPass>(
-              ArrayRef<int64_t>{});
-          pass->sizes.assign(clOuterLoopSizes.begin(), clOuterLoopSizes.end());
-          return pass;
-        });
+        "ranges of outer loops become static");