[mlir] [VectorOps] Expose lowering pass options programmatically
authoraartbik <ajcbik@google.com>
Wed, 8 Jul 2020 21:58:04 +0000 (14:58 -0700)
committeraartbik <ajcbik@google.com>
Wed, 8 Jul 2020 21:58:07 +0000 (14:58 -0700)
The ConvertVectorToLLVM pass defines options that can be passed
on the command line (currently only reassociation of FP reductions
through -convert-vector-to-llvm='reassociate-fp-reductions). This
CL enables setting these options programmatically (forward looking
to more options than just reassociation, as well as setting the
values from code rather than command line).

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D83420

mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

index cdff188..82aa828 100644 (file)
@@ -16,6 +16,18 @@ class ModuleOp;
 template <typename T>
 class OperationPass;
 
+/// Options to control Vector to LLVM lowering.
+///
+/// This should kept in sync with VectorToLLVM options defined for the
+/// ConvertVectorToLLVM pass in include/mlir/Conversion/Passes.td
+struct LowerVectorToLLVMOptions {
+  bool reassociateFPReductions = false;
+  LowerVectorToLLVMOptions &setReassociateFPReductions(bool r) {
+    reassociateFPReductions = r;
+    return *this;
+  }
+};
+
 /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
 /// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
 /// will be needed when invoking LLVM.
@@ -28,7 +40,8 @@ void populateVectorToLLVMConversionPatterns(
     bool reassociateFPReductions = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass();
+std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(
+    const LowerVectorToLLVMOptions &options = LowerVectorToLLVMOptions());
 
 } // namespace mlir
 
index 9a66daf..96a8fa4 100644 (file)
@@ -1180,6 +1180,9 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns(
 namespace {
 struct LowerVectorToLLVMPass
     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
+  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
+    this->reassociateFPReductions = options.reassociateFPReductions;
+  }
   void runOnOperation() override;
 };
 } // namespace
@@ -1210,6 +1213,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
   }
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() {
-  return std::make_unique<LowerVectorToLLVMPass>();
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
+  return std::make_unique<LowerVectorToLLVMPass>(options);
 }