Refactor the way that pass options are specified.
authorRiver Riddle <riverriddle@google.com>
Mon, 23 Dec 2019 23:54:55 +0000 (15:54 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 24 Dec 2019 00:48:22 +0000 (16:48 -0800)
This change refactors pass options to be more similar to how statistics are modeled. More specifically, the options are specified directly on the pass instead of in a separate options class. (Note that the behavior and specification for pass pipelines remains the same.) This brings about several benefits:
* The specification of options is much simpler
* The round-trip format of a pass can be generated automatically
* This gives a somewhat deeper integration with "configuring" a pass, which we could potentially expose to users in the future.

PiperOrigin-RevId: 286953824

mlir/g3doc/WritingAPass.md
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassOptions.h
mlir/include/mlir/Pass/PassRegistry.h
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassRegistry.cpp
mlir/test/Pass/pipeline-options-parsing.mlir
mlir/test/lib/Pass/TestPassManager.cpp
mlir/test/lib/Transforms/TestLoopParametricTiling.cpp

index 7847571..5119c46 100644 (file)
@@ -421,7 +421,8 @@ options           ::= '{' (key ('=' value)?)+ '}'
         pass pipeline, e.g. `cse` or `canonicalize`.
 *   `options`
     *   Options are pass specific key value pairs that are handled as described
-        in the instance specific pass options section.
+        in the [instance specific pass options](#instance-specific-pass-options)
+        section.
 
 For example, the following pipeline:
 
@@ -443,30 +444,47 @@ options in the format described above.
 ### Instance Specific Pass Options
 
 Options may be specified for a parametric pass. Individual options are defined
-using `llvm::cl::opt` flag definition rules. These options will then be parsed
-at pass construction time independently for each instance of the pass. The
-`PassRegistration` and `PassPipelineRegistration` templates take an additional
-optional template parameter that is the Option struct definition to be used for
-that pass. To use pass specific options, create a class that inherits from
-`mlir::PassOptions` and then add a new constructor that takes `const
-MyPassOptions&` and constructs the pass. When using `PassPipelineRegistration`,
-the constructor now takes a function with the signature `void (OpPassManager
-&pm, const MyPassOptions&)` which should construct the passes from the options
-and pass them to the pm. The user code will look like the following:
+using the [LLVM command line](https://llvm.org/docs/CommandLine.html) flag
+definition rules. These options will then be parsed at pass construction time
+independently for each instance of the pass. To provide options for passes, the
+`Option<>` and `OptionList<>` classes may be used:
 
 ```c++
-class MyPass ... {
-public:
-  MyPass(const MyPassOptions& options) ...
+struct MyPass ... {
+  /// Make sure that we have a valid default constructor and copy constructor to
+  /// make sure that the options are initialized properly.
+  MyPass() = default;
+  MyPass(const MyPass& pass) {}
+
+  // These just forward onto llvm::cl::list and llvm::cl::opt respectively.
+  Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
+  ListOption<int> exampleListOption{*this, "list-flag-name",
+                                    llvm::cl::desc("...")};
 };
+```
 
-struct MyPassOptions : public PassOptions<MyPassOptions> {
+For pass pipelines, the `PassPipelineRegistration` templates take an additional
+optional template parameter that is the Option struct definition to be used for
+that pipeline. To use pipeline specific options, create a class that inherits
+from `mlir::PassPipelineOptions` that contains the desired options. When using
+`PassPipelineRegistration`, the constructor now takes a function with the
+signature `void (OpPassManager &pm, const MyPipelineOptions&)` which should
+construct the passes from the options and pass them to the pm:
+
+```c++
+struct MyPipelineOptions : public PassPipelineOptions {
   // These just forward onto llvm::cl::list and llvm::cl::opt respectively.
   Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
-  List<int> exampleListOption{*this, "list-flag-name", llvm::cl::desc("...")};
+  ListOption<int> exampleListOption{*this, "list-flag-name",
+                                    llvm::cl::desc("...")};
 };
 
-static PassRegistration<MyPass, MyPassOptions> pass("my-pass", "description");
+
+static mlir::PassPipelineRegistration<MyPipelineOptions> pipeline(
+    "example-pipeline", "Run an example pipeline.",
+    [](OpPassManager &pm, const MyPipelineOptions &pipelineOptions) {
+      // Initialize the pass manager.
+    });
 ```
 
 ## Pass Statistics
index b4e8db8..bcb2973 100644 (file)
@@ -61,12 +61,40 @@ public:
   /// this is a generic OperationPass.
   Optional<StringRef> getOpName() const { return opName; }
 
+  //===--------------------------------------------------------------------===//
+  // Options
+  //===--------------------------------------------------------------------===//
+
+  /// This class represents a specific pass option, with a provided data type.
+  template <typename DataType>
+  struct Option : public detail::PassOptions::Option<DataType> {
+    template <typename... Args>
+    Option(Pass &parent, StringRef arg, Args &&... args)
+        : detail::PassOptions::Option<DataType>(parent.passOptions, arg,
+                                                std::forward<Args>(args)...) {}
+    using detail::PassOptions::Option<DataType>::operator=;
+  };
+  /// This class represents a specific pass option that contains a list of
+  /// values of the provided data type.
+  template <typename DataType>
+  struct ListOption : public detail::PassOptions::ListOption<DataType> {
+    template <typename... Args>
+    ListOption(Pass &parent, StringRef arg, Args &&... args)
+        : detail::PassOptions::ListOption<DataType>(
+              parent.passOptions, arg, std::forward<Args>(args)...) {}
+    using detail::PassOptions::ListOption<DataType>::operator=;
+  };
+
+  /// Attempt to initialize the options of this pass from the given string.
+  LogicalResult initializeOptions(StringRef options);
+
   /// Prints out the pass in the textual representation of pipelines. If this is
   /// an adaptor pass, print with the op_name(sub_pass,...) format.
-  /// Note: The default implementation uses the class name and does not respect
-  /// options used to construct the pass. Override this method to allow for your
-  /// pass to be to be round-trippable to the textual format.
-  virtual void printAsTextualPipeline(raw_ostream &os);
+  void printAsTextualPipeline(raw_ostream &os);
+
+  //===--------------------------------------------------------------------===//
+  // Statistics
+  //===--------------------------------------------------------------------===//
 
   /// This class represents a single pass statistic. This statistic functions
   /// similarly to an unsigned integer value, and may be updated and incremented
@@ -119,6 +147,10 @@ protected:
     return getPassState().analysisManager;
   }
 
+  /// Copy the option values from 'other', which is another instance of this
+  /// pass.
+  void copyOptionValuesFrom(const Pass *other);
+
 private:
   /// Forwarding function to execute this pass on the given operation.
   LLVM_NODISCARD
@@ -141,6 +173,9 @@ private:
   /// The set of statistics held by this pass.
   std::vector<Statistic *> statistics;
 
+  /// The pass options registered to this pass instance.
+  detail::PassOptions passOptions;
+
   /// Allow access to 'clone' and 'run'.
   friend class OpPassManager;
 };
@@ -204,7 +239,9 @@ protected:
 
   /// A clone method to create a copy of this pass.
   std::unique_ptr<Pass> clone() const override {
-    return std::make_unique<PassT>(*static_cast<const PassT *>(this));
+    auto newInst = std::make_unique<PassT>(*static_cast<const PassT *>(this));
+    newInst->copyOptionValuesFrom(this);
+    return newInst;
   }
 
   /// Returns the analysis for the parent operation if it exists.
index 8ebeead..0ecb7ba 100644 (file)
 
 namespace mlir {
 namespace detail {
-/// Base class for PassOptions<T> that holds all of the non-CRTP features.
-class PassOptionsBase : protected llvm::cl::SubCommand {
+/// Base container class and manager for all pass options.
+class PassOptions : protected llvm::cl::SubCommand {
+private:
+  /// This is the type-erased option base class. This provides some additional
+  /// hooks into the options that are not available via llvm::cl::Option.
+  class OptionBase {
+  public:
+    virtual ~OptionBase() = default;
+
+    /// Out of line virtual function to provide home for the class.
+    virtual void anchor();
+
+    /// Print the name and value of this option to the given stream.
+    virtual void print(raw_ostream &os) = 0;
+
+    /// Return the argument string of this option.
+    StringRef getArgStr() const { return getOption()->ArgStr; }
+
+  protected:
+    /// Return the main option instance.
+    virtual const llvm::cl::Option *getOption() const = 0;
+
+    /// Copy the value from the given option into this one.
+    virtual void copyValueFrom(const OptionBase &other) = 0;
+
+    /// Allow access to private methods.
+    friend PassOptions;
+  };
+
+  /// This is the parser that is used by pass options that use literal options.
+  /// This is a thin wrapper around the llvm::cl::parser, that exposes some
+  /// additional methods.
+  template <typename DataType>
+  struct GenericOptionParser : public llvm::cl::parser<DataType> {
+    using llvm::cl::parser<DataType>::parser;
+
+    /// Returns an argument name that maps to the specified value.
+    Optional<StringRef> findArgStrForValue(const DataType &value) {
+      for (auto &it : this->Values)
+        if (it.V.compare(value))
+          return it.Name;
+      return llvm::None;
+    }
+  };
+
+  /// The specific parser to use depending on llvm::cl parser used. This is only
+  /// necessary because we need to provide additional methods for certain data
+  /// type parsers.
+  /// TODO(riverriddle) We should upstream the methods in GenericOptionParser to
+  /// avoid the need to do this.
+  template <typename DataType>
+  using OptionParser =
+      std::conditional_t<std::is_base_of<llvm::cl::generic_parser_base,
+                                         llvm::cl::parser<DataType>>::value,
+                         GenericOptionParser<DataType>,
+                         llvm::cl::parser<DataType>>;
+
+  /// Utility methods for printing option values.
+  template <typename DataT>
+  static void printOptionValue(raw_ostream &os,
+                               GenericOptionParser<DataT> &parser,
+                               const DataT &value) {
+    if (Optional<StringRef> argStr = parser.findArgStrForValue(value))
+      os << argStr;
+    else
+      llvm_unreachable("unknown data value for option");
+  }
+  template <typename DataT, typename ParserT>
+  static void printOptionValue(raw_ostream &os, ParserT &parser,
+                               const DataT &value) {
+    os << value;
+  }
+  template <typename ParserT>
+  static void printOptionValue(raw_ostream &os, ParserT &parser,
+                               const bool &value) {
+    os << (value ? StringRef("true") : StringRef("false"));
+  }
+
 public:
   /// This class represents a specific pass option, with a provided data type.
-  template <typename DataType> struct Option : public llvm::cl::opt<DataType> {
+  template <typename DataType>
+  class Option : public llvm::cl::opt<DataType, /*ExternalStorage=*/false,
+                                      OptionParser<DataType>>,
+                 public OptionBase {
+  public:
     template <typename... Args>
-    Option(PassOptionsBase &parent, StringRef arg, Args &&... args)
-        : llvm::cl::opt<DataType>(arg, llvm::cl::sub(parent),
-                                  std::forward<Args>(args)...) {
+    Option(PassOptions &parent, StringRef arg, Args &&... args)
+        : llvm::cl::opt<DataType, /*ExternalStorage=*/false,
+                        OptionParser<DataType>>(arg, llvm::cl::sub(parent),
+                                                std::forward<Args>(args)...) {
       assert(!this->isPositional() && !this->isSink() &&
              "sink and positional options are not supported");
+      parent.options.push_back(this);
+    }
+    using llvm::cl::opt<DataType, /*ExternalStorage=*/false,
+                        OptionParser<DataType>>::operator=;
+    ~Option() override = default;
+
+  private:
+    /// Return the main option instance.
+    const llvm::cl::Option *getOption() const final { return this; }
+
+    /// Print the name and value of this option to the given stream.
+    void print(raw_ostream &os) final {
+      os << this->ArgStr << '=';
+      printOptionValue(os, this->getParser(), this->getValue());
+    }
+
+    /// Copy the value from the given option into this one.
+    void copyValueFrom(const OptionBase &other) final {
+      this->setValue(static_cast<const Option<DataType> &>(other).getValue());
     }
   };
 
   /// This class represents a specific pass option that contains a list of
   /// values of the provided data type.
-  template <typename DataType> struct List : public llvm::cl::list<DataType> {
+  template <typename DataType>
+  class ListOption : public llvm::cl::list<DataType, /*StorageClass=*/bool,
+                                           OptionParser<DataType>>,
+                     public OptionBase {
+  public:
     template <typename... Args>
-    List(PassOptionsBase &parent, StringRef arg, Args &&... args)
-        : llvm::cl::list<DataType>(arg, llvm::cl::sub(parent),
-                                   std::forward<Args>(args)...) {
+    ListOption(PassOptions &parent, StringRef arg, Args &&... args)
+        : llvm::cl::list<DataType, /*StorageClass=*/bool,
+                         OptionParser<DataType>>(arg, llvm::cl::sub(parent),
+                                                 std::forward<Args>(args)...) {
       assert(!this->isPositional() && !this->isSink() &&
              "sink and positional options are not supported");
+      parent.options.push_back(this);
+    }
+    ~ListOption() override = default;
+
+    /// Allow assigning from an ArrayRef.
+    ListOption<DataType> &operator=(ArrayRef<DataType> values) {
+      (*this)->assign(values.begin(), values.end());
+      return *this;
+    }
+
+    std::vector<DataType> *operator->() { return &*this; }
+
+  private:
+    /// Return the main option instance.
+    const llvm::cl::Option *getOption() const final { return this; }
+
+    /// Print the name and value of this option to the given stream.
+    void print(raw_ostream &os) final {
+      os << this->ArgStr << '=';
+      auto printElementFn = [&](const DataType &value) {
+        printOptionValue(os, this->getParser(), value);
+      };
+      interleave(*this, os, printElementFn, ",");
+    }
+
+    /// Copy the value from the given option into this one.
+    void copyValueFrom(const OptionBase &other) final {
+      (*this) = ArrayRef<DataType>((ListOption<DataType> &)other);
     }
   };
 
+  PassOptions() = default;
+
+  /// Copy the option values from 'other' into 'this', where 'other' has the
+  /// same options as 'this'.
+  void copyOptionValuesFrom(const PassOptions &other);
+
   /// Parse options out as key=value pairs that can then be handed off to the
   /// `llvm::cl` command line passing infrastructure. Everything is space
   /// separated.
   LogicalResult parseFromString(StringRef options);
+
+  /// Print the options held by this struct in a form that can be parsed via
+  /// 'parseFromString'.
+  void print(raw_ostream &os);
+
+private:
+  /// A list of all of the opaque options.
+  std::vector<OptionBase *> options;
 };
 } // end namespace detail
 
-/// Subclasses of PassOptions provide a set of options that can be used to
-/// initialize a pass instance. See PassRegistration for usage details.
+//===----------------------------------------------------------------------===//
+// PassPipelineOptions
+//===----------------------------------------------------------------------===//
+
+/// Subclasses of PassPipelineOptions provide a set of options that can be used
+/// to initialize a pass pipeline. See PassPipelineRegistration for usage
+/// details.
 ///
 /// Usage:
 ///
-/// struct MyPassOptions : PassOptions<MyPassOptions> {
-///   List<int> someListFlag{
+/// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
+///   ListOption<int> someListFlag{
 ///        *this, "flag-name", llvm::cl::MiscFlags::CommaSeparated,
 ///        llvm::cl::desc("...")};
 /// };
-template <typename T> class PassOptions : public detail::PassOptionsBase {
+template <typename T> class PassPipelineOptions : public detail::PassOptions {
 public:
   /// Factory that parses the provided options and returns a unique_ptr to the
   /// struct.
@@ -81,7 +233,8 @@ public:
 
 /// A default empty option struct to be used for passes that do not need to take
 /// any options.
-struct EmptyPassOptions : public PassOptions<EmptyPassOptions> {};
+struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {
+};
 
 } // end namespace mlir
 
index e07b985..c5604c0 100644 (file)
@@ -25,6 +25,7 @@ class Pass;
 /// also parse options and return success() if parsing succeeded.
 using PassRegistryFunction =
     std::function<LogicalResult(OpPassManager &, StringRef options)>;
+using PassAllocatorFunction = std::function<std::unique_ptr<Pass>()>;
 
 /// A special type used by transformation passes to provide an address that can
 /// act as a unique identifier during pass registration.
@@ -56,7 +57,7 @@ public:
 
 protected:
   PassRegistryEntry(StringRef arg, StringRef description,
-                    PassRegistryFunction builder)
+                    const PassRegistryFunction &builder)
       : arg(arg), description(description), builder(builder) {}
 
 private:
@@ -74,7 +75,7 @@ private:
 class PassPipelineInfo : public PassRegistryEntry {
 public:
   PassPipelineInfo(StringRef arg, StringRef description,
-                   PassRegistryFunction builder)
+                   const PassRegistryFunction &builder)
       : PassRegistryEntry(arg, description, builder) {}
 };
 
@@ -84,8 +85,7 @@ public:
   /// PassInfo constructor should not be invoked directly, instead use
   /// PassRegistration or registerPass.
   PassInfo(StringRef arg, StringRef description, const PassID *passID,
-           PassRegistryFunction allocator)
-      : PassRegistryEntry(arg, description, allocator) {}
+           const PassAllocatorFunction &allocator);
 };
 
 //===----------------------------------------------------------------------===//
@@ -100,80 +100,28 @@ void registerPassPipeline(StringRef arg, StringRef description,
 /// Register a specific dialect pass allocator function with the system,
 /// typically used through the PassRegistration template.
 void registerPass(StringRef arg, StringRef description, const PassID *passID,
-                  const PassRegistryFunction &function);
-
-namespace detail {
-
-// Calls `pm.addPass(std::move(pass))` to avoid including the PassManager
-// header. Only used in `makePassRegistryFunction`.
-void addPassToPassManager(OpPassManager &pm, std::unique_ptr<Pass> pass);
-
-// Helper function which constructs a PassRegistryFunction that parses options
-// into a struct of type `Options` and then calls constructor(options) to
-// build the pass.
-template <typename Options, typename PassConstructor>
-PassRegistryFunction makePassRegistryFunction(PassConstructor constructor) {
-  return [=](OpPassManager &pm, StringRef optionsStr) {
-    Options options;
-    if (failed(options.parseFromString(optionsStr)))
-      return failure();
-    addPassToPassManager(pm, constructor(options));
-    return success();
-  };
-}
-
-} // end namespace detail
+                  const PassAllocatorFunction &function);
 
 /// PassRegistration provides a global initializer that registers a Pass
-/// allocation routine for a concrete pass instance.  The third argument is
+/// allocation routine for a concrete pass instance. The third argument is
 /// optional and provides a callback to construct a pass that does not have
 /// a default constructor.
 ///
 /// Usage:
 ///
-///   // At namespace scope.
+///   /// At namespace scope.
 ///   static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
 ///
-///   // Same, but also providing an Options struct.
-///   static PassRegistration<MyPass, MyPassOptions> reg("my-pass", "Docs...");
-template <typename ConcretePass, typename Options = EmptyPassOptions>
-struct PassRegistration {
+template <typename ConcretePass> struct PassRegistration {
 
   PassRegistration(StringRef arg, StringRef description,
-                   const std::function<std::unique_ptr<Pass>(const Options &)>
-                       &constructor) {
-    registerPass(arg, description, PassID::getID<ConcretePass>(),
-                 detail::makePassRegistryFunction<Options>(constructor));
+                   const PassAllocatorFunction &constructor) {
+    registerPass(arg, description, PassID::getID<ConcretePass>(), constructor);
   }
 
-  PassRegistration(StringRef arg, StringRef description) {
-    registerPass(
-        arg, description, PassID::getID<ConcretePass>(),
-        detail::makePassRegistryFunction<Options>([](const Options &options) {
-          return std::make_unique<ConcretePass>(options);
-        }));
-  }
-};
-
-/// Convenience specialization of PassRegistration for EmptyPassOptions that
-/// does not pass an empty options struct to the pass constructor.
-template <typename ConcretePass>
-struct PassRegistration<ConcretePass, EmptyPassOptions> {
-  PassRegistration(StringRef arg, StringRef description,
-                   const std::function<std::unique_ptr<Pass>()> &constructor) {
-    registerPass(
-        arg, description, PassID::getID<ConcretePass>(),
-        detail::makePassRegistryFunction<EmptyPassOptions>(
-            [=](const EmptyPassOptions &options) { return constructor(); }));
-  }
-
-  PassRegistration(StringRef arg, StringRef description) {
-    registerPass(arg, description, PassID::getID<ConcretePass>(),
-                 detail::makePassRegistryFunction<EmptyPassOptions>(
-                     [](const EmptyPassOptions &options) {
-                       return std::make_unique<ConcretePass>();
-                     }));
-  }
+  PassRegistration(StringRef arg, StringRef description)
+      : PassRegistration(arg, description,
+                         [] { return std::make_unique<ConcretePass>(); }) {}
 };
 
 /// PassPipelineRegistration provides a global initializer that registers a Pass
@@ -189,7 +137,8 @@ struct PassRegistration<ConcretePass, EmptyPassOptions> {
 ///
 ///   static PassPipelineRegistration Unused("unused", "Unused pass",
 ///                                          pipelineBuilder);
-template <typename Options = EmptyPassOptions> struct PassPipelineRegistration {
+template <typename Options = EmptyPipelineOptions>
+struct PassPipelineRegistration {
   PassPipelineRegistration(
       StringRef arg, StringRef description,
       std::function<void(OpPassManager &, const Options &options)> builder) {
@@ -206,7 +155,7 @@ template <typename Options = EmptyPassOptions> struct PassPipelineRegistration {
 
 /// Convenience specialization of PassPipelineRegistration for EmptyPassOptions
 /// that does not pass an empty options struct to the pass builder function.
-template <> struct PassPipelineRegistration<EmptyPassOptions> {
+template <> struct PassPipelineRegistration<EmptyPipelineOptions> {
   PassPipelineRegistration(StringRef arg, StringRef description,
                            std::function<void(OpPassManager &)> builder) {
     registerPassPipeline(arg, description,
index 1150960..68392c3 100644 (file)
@@ -35,17 +35,17 @@ namespace {
 /// 2) Lower the body of the spirv::ModuleOp.
 class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
 public:
-  GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize)
-      : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
+  GPUToSPIRVPass() = default;
+  GPUToSPIRVPass(const GPUToSPIRVPass &) {}
+  GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) {
+    this->workGroupSize = workGroupSize;
+  }
+
   void runOnModule() override;
 
 private:
-  SmallVector<int64_t, 3> workGroupSize;
-};
-
-/// Command line option to specify the workgroup size.
-struct GPUToSPIRVPassOptions : public PassOptions<GPUToSPIRVPassOptions> {
-  List<unsigned> workGroupSize{
+  /// Command line option to specify the workgroup size.
+  ListOption<int64_t> workGroupSize{
       *this, "workgroup-size",
       llvm::cl::desc(
           "Workgroup Sizes in the SPIR-V module for x, followed by y, followed "
@@ -92,11 +92,5 @@ mlir::createConvertGPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) {
   return std::make_unique<GPUToSPIRVPass>(workGroupSize);
 }
 
-static PassRegistration<GPUToSPIRVPass, GPUToSPIRVPassOptions>
-    pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect",
-         [](const GPUToSPIRVPassOptions &passOptions) {
-           SmallVector<int64_t, 3> workGroupSize;
-           workGroupSize.assign(passOptions.workGroupSize.begin(),
-                                passOptions.workGroupSize.end());
-           return std::make_unique<GPUToSPIRVPass>(workGroupSize);
-         });
+static PassRegistration<GPUToSPIRVPass>
+    pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
index 22e58cc..8877cc5 100644 (file)
@@ -36,6 +36,17 @@ using namespace mlir::detail;
 /// single .o file.
 void Pass::anchor() {}
 
+/// Attempt to initialize the options of this pass from the given string.
+LogicalResult Pass::initializeOptions(StringRef options) {
+  return passOptions.parseFromString(options);
+}
+
+/// Copy the option values from 'other', which is another instance of this
+/// pass.
+void Pass::copyOptionValuesFrom(const Pass *other) {
+  passOptions.copyOptionValuesFrom(other->passOptions);
+}
+
 /// Prints out the pass in the textual representation of pipelines. If this is
 /// an adaptor pass, print with the op_name(sub_pass,...) format.
 void Pass::printAsTextualPipeline(raw_ostream &os) {
@@ -46,11 +57,14 @@ void Pass::printAsTextualPipeline(raw_ostream &os) {
       pm.printAsTextualPipeline(os);
       os << ")";
     });
-  } else if (const PassInfo *info = lookupPassInfo()) {
+    return;
+  }
+  // Otherwise, print the pass argument followed by its options.
+  if (const PassInfo *info = lookupPassInfo())
     os << info->getPassArgument();
-  } else {
+  else
     os << getName();
-  }
+  passOptions.print(os);
 }
 
 /// Forwarding function to execute this pass.
index 93753d3..1c5193d 100644 (file)
@@ -24,10 +24,15 @@ static llvm::ManagedStatic<DenseMap<const PassID *, PassInfo>> passRegistry;
 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
     passPipelineRegistry;
 
-// Helper to avoid exposing OpPassManager.
-void mlir::detail::addPassToPassManager(OpPassManager &pm,
-                                        std::unique_ptr<Pass> pass) {
-  pm.addPass(std::move(pass));
+/// Utility to create a default registry function from a pass instance.
+static PassRegistryFunction
+buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
+  return [=](OpPassManager &pm, StringRef options) {
+    std::unique_ptr<Pass> pass = allocator();
+    LogicalResult result = pass->initializeOptions(options);
+    pm.addPass(std::move(pass));
+    return result;
+  };
 }
 
 //===----------------------------------------------------------------------===//
@@ -46,9 +51,13 @@ void mlir::registerPassPipeline(StringRef arg, StringRef description,
 // PassInfo
 //===----------------------------------------------------------------------===//
 
+PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID,
+                   const PassAllocatorFunction &allocator)
+    : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)) {}
+
 void mlir::registerPass(StringRef arg, StringRef description,
                         const PassID *passID,
-                        const PassRegistryFunction &function) {
+                        const PassAllocatorFunction &function) {
   PassInfo passInfo(arg, description, passID, function);
   bool inserted = passRegistry->try_emplace(passID, passInfo).second;
   assert(inserted && "Pass registered multiple times");
@@ -67,7 +76,19 @@ const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) {
 // PassOptions
 //===----------------------------------------------------------------------===//
 
-LogicalResult PassOptionsBase::parseFromString(StringRef options) {
+/// Out of line virtual function to provide home for the class.
+void detail::PassOptions::OptionBase::anchor() {}
+
+/// Copy the option values from 'other'.
+void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
+  assert(options.size() == other.options.size());
+  if (options.empty())
+    return;
+  for (auto optionsIt : llvm::zip(options, other.options))
+    std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
+}
+
+LogicalResult detail::PassOptions::parseFromString(StringRef options) {
   // TODO(parkers): Handle escaping strings.
   // NOTE: `options` is modified in place to always refer to the unprocessed
   // part of the string.
@@ -99,7 +120,6 @@ LogicalResult PassOptionsBase::parseFromString(StringRef options) {
     auto it = OptionsMap.find(key);
     if (it == OptionsMap.end()) {
       llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
-
       return failure();
     }
     if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
@@ -109,6 +129,28 @@ LogicalResult PassOptionsBase::parseFromString(StringRef options) {
   return success();
 }
 
+/// Print the options held by this struct in a form that can be parsed via
+/// 'parseFromString'.
+void detail::PassOptions::print(raw_ostream &os) {
+  // If there are no options, there is nothing left to do.
+  if (OptionsMap.empty())
+    return;
+
+  // Sort the options to make the ordering deterministic.
+  SmallVector<OptionBase *, 4> orderedOptions(options.begin(), options.end());
+  llvm::array_pod_sort(orderedOptions.begin(), orderedOptions.end(),
+                       [](OptionBase *const *lhs, OptionBase *const *rhs) {
+                         return (*lhs)->getArgStr().compare(
+                             (*rhs)->getArgStr());
+                       });
+
+  // Interleave the options with ' '.
+  os << '{';
+  interleave(
+      orderedOptions, os, [&](OptionBase *option) { option->print(os); }, " ");
+  os << '}';
+}
+
 //===----------------------------------------------------------------------===//
 // TextualPassPipeline Parser
 //===----------------------------------------------------------------------===//
index 02452a3..bfb24af 100644 (file)
@@ -13,6 +13,6 @@
 // CHECK_ERROR_4: 'notaninteger' value invalid for integer argument
 // CHECK_ERROR_5: string option: may only occur zero or one times
 
-// CHECK_1: test-options-pass{list=1,2,3,4,5 string-list=a,b,c,d string=some_value}
-// CHECK_2: test-options-pass{list=1 string-list=a,b}
-// CHECK_3: module(func(test-options-pass{list=3}), func(test-options-pass{list=1,2,3,4}))
+// CHECK_1: test-options-pass{list=1,2,3,4,5 string=some_value string-list=a,b,c,d}
+// CHECK_2: test-options-pass{list=1 string= string-list=a,b}
+// CHECK_3: module(func(test-options-pass{list=3 string= string-list=}), func(test-options-pass{list=1,2,3,4 string= string-list=}))
index 2e81163..cc926e1 100644 (file)
@@ -21,43 +21,34 @@ struct TestFunctionPass : public FunctionPass<TestFunctionPass> {
 };
 class TestOptionsPass : public FunctionPass<TestOptionsPass> {
 public:
-  struct Options : public PassOptions<Options> {
-    List<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
-                         llvm::cl::desc("Example list option")};
-    List<std::string> stringListOption{
+  struct Options : public PassPipelineOptions<Options> {
+    ListOption<int> listOption{*this, "list",
+                               llvm::cl::MiscFlags::CommaSeparated,
+                               llvm::cl::desc("Example list option")};
+    ListOption<std::string> stringListOption{
         *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
         llvm::cl::desc("Example string list option")};
     Option<std::string> stringOption{*this, "string",
                                      llvm::cl::desc("Example string option")};
   };
+  TestOptionsPass() = default;
+  TestOptionsPass(const TestOptionsPass &) {}
   TestOptionsPass(const Options &options) {
-    listOption.assign(options.listOption.begin(), options.listOption.end());
-    stringOption = options.stringOption;
-    stringListOption.assign(options.stringListOption.begin(),
-                            options.stringListOption.end());
-  }
-
-  void printAsTextualPipeline(raw_ostream &os) final {
-    os << "test-options-pass{";
-    if (!listOption.empty()) {
-      os << "list=";
-      // Not interleaveComma to avoid spaces between the elements.
-      interleave(listOption, os, ",");
-    }
-    if (!stringListOption.empty()) {
-      os << " string-list=";
-      interleave(stringListOption, os, ",");
-    }
-    if (!stringOption.empty())
-      os << " string=" << stringOption;
-    os << "}";
+    listOption->assign(options.listOption.begin(), options.listOption.end());
+    stringOption.setValue(options.stringOption);
+    stringListOption->assign(options.stringListOption.begin(),
+                             options.stringListOption.end());
   }
 
   void runOnFunction() final {}
 
-  SmallVector<int64_t, 4> listOption;
-  SmallVector<std::string, 4> stringListOption;
-  std::string stringOption;
+  ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
+                             llvm::cl::desc("Example list option")};
+  ListOption<std::string> stringListOption{
+      *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc("Example string list option")};
+  Option<std::string> stringOption{*this, "string",
+                                   llvm::cl::desc("Example string option")};
 };
 
 /// A test pass that always aborts to enable testing the crash recovery
@@ -97,7 +88,7 @@ static void testNestedPipelineTextual(OpPassManager &pm) {
   (void)parsePassPipeline("test-pm-nested-pipeline", pm);
 }
 
-static PassRegistration<TestOptionsPass, TestOptionsPass::Options>
+static PassRegistration<TestOptionsPass>
     reg("test-options-pass", "Test options parsing capabilities");
 
 static PassRegistration<TestModulePass>
index 7b0cdca..e793ee5 100644 (file)
@@ -25,18 +25,10 @@ namespace {
 class SimpleParametricLoopTilingPass
     : public FunctionPass<SimpleParametricLoopTilingPass> {
 public:
-  struct Options : public PassOptions<Options> {
-    List<int> clOuterLoopSizes{
-        *this, "test-outer-loop-sizes", llvm::cl::MiscFlags::CommaSeparated,
-        llvm::cl::desc(
-            "fixed number of iterations that the outer loops should have")};
-  };
-
-  explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes)
-      : sizes(outerLoopSizes.begin(), outerLoopSizes.end()) {}
-  explicit SimpleParametricLoopTilingPass(const Options &options) {
-    sizes.assign(options.clOuterLoopSizes.begin(),
-                 options.clOuterLoopSizes.end());
+  SimpleParametricLoopTilingPass() = default;
+  SimpleParametricLoopTilingPass(const SimpleParametricLoopTilingPass &) {}
+  explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes) {
+    sizes = outerLoopSizes;
   }
 
   void runOnFunction() override {
@@ -49,7 +41,10 @@ public:
     });
   }
 
-  SmallVector<int64_t, 4> sizes;
+  ListOption<int64_t> sizes{
+      *this, "test-outer-loop-sizes", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc(
+          "fixed number of iterations that the outer loops should have")};
 };
 } // end namespace
 
@@ -58,8 +53,7 @@ mlir::createSimpleParametricTilingPass(ArrayRef<int64_t> outerLoopSizes) {
   return std::make_unique<SimpleParametricLoopTilingPass>(outerLoopSizes);
 }
 
-static PassRegistration<SimpleParametricLoopTilingPass,
-                        SimpleParametricLoopTilingPass::Options>
+static PassRegistration<SimpleParametricLoopTilingPass>
     reg("test-extract-fixed-outer-loops",
         "test application of parametric tiling to the outer loops so that the "
         "ranges of outer loops become static");