[mlir][Pass][NFC] Replace usages of ModulePass with OperationPass<ModuleOp>
authorRiver Riddle <riddleriver@gmail.com>
Tue, 7 Apr 2020 20:55:34 +0000 (13:55 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 7 Apr 2020 21:08:52 +0000 (14:08 -0700)
ModulePass doesn't provide any special utilities and thus doesn't give enough benefit to warrant a special pass class. This revision replaces all usages with the more general OperationPass.

Differential Revision: https://reviews.llvm.org/D77339

26 files changed:
mlir/docs/Tutorials/Toy/Ch-6.md
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/include/mlir/Pass/Pass.h
mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
mlir/lib/Transforms/OpStats.cpp
mlir/lib/Transforms/ViewOpGraph.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/IR/TestFunc.cpp
mlir/test/lib/IR/TestSideEffects.cpp
mlir/test/lib/IR/TestSymbolUses.cpp
mlir/test/lib/Pass/TestPassManager.cpp
mlir/test/lib/Transforms/TestAllReduceLowering.cpp
mlir/test/lib/Transforms/TestCallGraph.cpp
mlir/test/lib/Transforms/TestOpaqueLoc.cpp

index 0444d2a..e1dfc00 100644 (file)
@@ -105,7 +105,7 @@ We want to completely lower to LLVM, so we use a `FullConversion`. This ensures
 that only legal operations will remain after the conversion.
 
 ```c++
-  mlir::ModuleOp module = getModule();
+  mlir::ModuleOp module = getOperation();
   if (mlir::failed(mlir::applyFullConversion(module, target, patterns,
                                              &typeConverter)))
     signalPassFailure();
index f6dcba2..99465d3 100644 (file)
@@ -153,12 +153,13 @@ private:
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
-  void runOnModule() final;
+struct ToyToLLVMLoweringPass
+    : public OperationPass<ToyToLLVMLoweringPass, ModuleOp> {
+  void runOnOperation() final;
 };
 } // end anonymous namespace
 
-void ToyToLLVMLoweringPass::runOnModule() {
+void ToyToLLVMLoweringPass::runOnOperation() {
   // The first thing to define is the conversion target. This will define the
   // final target for this lowering. For this lowering, we are only targeting
   // the LLVM dialect.
@@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() {
 
   // We want to completely lower to LLVM, so we use a `FullConversion`. This
   // ensures that only legal operations will remain after the conversion.
-  auto module = getModule();
+  auto module = getOperation();
   if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
     signalPassFailure();
 }
index f6dcba2..99465d3 100644 (file)
@@ -153,12 +153,13 @@ private:
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
-  void runOnModule() final;
+struct ToyToLLVMLoweringPass
+    : public OperationPass<ToyToLLVMLoweringPass, ModuleOp> {
+  void runOnOperation() final;
 };
 } // end anonymous namespace
 
-void ToyToLLVMLoweringPass::runOnModule() {
+void ToyToLLVMLoweringPass::runOnOperation() {
   // The first thing to define is the conversion target. This will define the
   // final target for this lowering. For this lowering, we are only targeting
   // the LLVM dialect.
@@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() {
 
   // We want to completely lower to LLVM, so we use a `FullConversion`. This
   // ensures that only legal operations will remain after the conversion.
-  auto module = getModule();
+  auto module = getOperation();
   if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
     signalPassFailure();
 }
index 80c4ddf..c1eec4f 100644 (file)
@@ -341,24 +341,9 @@ template <typename T> struct FunctionPass : public OperationPass<T, FuncOp> {
       runOnFunction();
   }
 
-  /// Return the current module being transformed.
+  /// Return the current function being transformed.
   FuncOp getFunction() { return this->getOperation(); }
 };
-
-/// A model for providing module pass specific utilities.
-///
-/// Derived module passes are expected to provide the following:
-///   - A 'void runOnModule()' method.
-template <typename T> struct ModulePass : public OperationPass<T, ModuleOp> {
-  /// The polymorphic API that runs the pass over the currently held module.
-  virtual void runOnModule() = 0;
-
-  /// The polymorphic API that runs the pass over the currently held operation.
-  void runOnOperation() final { runOnModule(); }
-
-  /// Return the current module being transformed.
-  ModuleOp getModule() { return this->getOperation(); }
-};
 } // end namespace mlir
 
 #endif // MLIR_PASS_PASS_H
index 91f3cc9..08b187f 100644 (file)
@@ -163,16 +163,17 @@ void mlir::populateAVX512ToLLVMConversionPatterns(
 }
 
 namespace {
-struct ConvertAVX512ToLLVMPass : public ModulePass<ConvertAVX512ToLLVMPass> {
+struct ConvertAVX512ToLLVMPass
+    : public OperationPass<ConvertAVX512ToLLVMPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertAVX512ToLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void ConvertAVX512ToLLVMPass::runOnModule() {
+void ConvertAVX512ToLLVMPass::runOnOperation() {
   // Convert to the LLVM IR dialect.
   OwningRewritePatternList patterns;
   LLVMTypeConverter converter(&getContext());
@@ -186,8 +187,8 @@ void ConvertAVX512ToLLVMPass::runOnModule() {
   target.addIllegalDialect<avx512::AVX512Dialect>();
   target.addDynamicallyLegalOp<FuncOp>(
       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
-  if (failed(
-          applyPartialConversion(getModule(), target, patterns, &converter))) {
+  if (failed(applyPartialConversion(getOperation(), target, patterns,
+                                    &converter))) {
     signalPassFailure();
   }
 }
index 38c092a..71fe129 100644 (file)
@@ -61,7 +61,7 @@ namespace {
 ///
 /// Intermediate data structures are allocated on the stack.
 class GpuLaunchFuncToCudaCallsPass
-    : public ModulePass<GpuLaunchFuncToCudaCallsPass> {
+    : public OperationPass<GpuLaunchFuncToCudaCallsPass, ModuleOp> {
 private:
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertGpuLaunchFuncToCudaCalls
@@ -126,20 +126,19 @@ private:
 
 public:
   // Run the dialect converter on the module.
-  void runOnModule() override {
+  void runOnOperation() override {
     // Cache the LLVMDialect for the current module.
     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
     // Cache the used LLVM types.
     initializeCachedTypes();
 
-    getModule().walk([this](mlir::gpu::LaunchFuncOp op) {
-      translateGpuLaunchCalls(op);
-    });
+    getOperation().walk(
+        [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
 
     // GPU kernel modules are no longer necessary since we have a global
     // constant with the CUBIN data.
     for (auto m :
-         llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
+         llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
       m.erase();
   }
 
@@ -160,7 +159,7 @@ private:
 // The types in comments give the actual types expected/returned but the API
 // uses void pointers. This is fine as they have the same linkage in C.
 void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
   OpBuilder builder(module.getBody()->getTerminator());
   if (!module.lookupSymbol(cuModuleLoadName)) {
     builder.create<LLVM::LLVMFuncOp>(
@@ -391,7 +390,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
                                                builder.getI32IntegerAttr(0));
   // Create an LLVM global with CUBIN extracted from the kernel annotation and
   // obtain a pointer to the first byte in it.
-  auto kernelModule = getModule().lookupSymbol<gpu::GPUModuleOp>(
+  auto kernelModule = getOperation().lookupSymbol<gpu::GPUModuleOp>(
       launchOp.getKernelModuleName());
   assert(kernelModule && "expected a kernel module");
 
@@ -412,7 +411,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
   // in the called helper function.
   auto cuModule = allocatePointer(builder, loc);
   auto cuModuleLoad =
-      getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
                                builder.getSymbolRefAttr(cuModuleLoad),
                                ArrayRef<Value>{cuModule, data});
@@ -423,20 +422,20 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
   auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
   auto cuFunction = allocatePointer(builder, loc);
   auto cuModuleGetFunction =
-      getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
   builder.create<LLVM::CallOp>(
       loc, ArrayRef<Type>{getCUResultType()},
       builder.getSymbolRefAttr(cuModuleGetFunction),
       ArrayRef<Value>{cuFunction, cuOwningModuleRef, kernelName});
   // Grab the global stream needed for execution.
   auto cuGetStreamHelper =
-      getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
   auto cuStream = builder.create<LLVM::CallOp>(
       loc, ArrayRef<Type>{getPointerType()},
       builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value>{});
   // Invoke the function with required arguments.
   auto cuLaunchKernel =
-      getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
   auto cuFunctionRef =
       builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
   auto paramsArray = setupParamsArray(launchOp, builder);
@@ -458,7 +457,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
                       nullpointer /* extra */});
   // Sync on the stream to make it synchronous.
   auto cuStreamSync =
-      getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
                                builder.getSymbolRefAttr(cuStreamSync),
                                ArrayRef<Value>(cuStream.getResult(0)));
index 1102ef1..edee502 100644 (file)
@@ -33,18 +33,18 @@ namespace {
 /// replace it).
 ///
 /// 2) Lower the body of the spirv::ModuleOp.
-struct GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+struct GPUToSPIRVPass : public OperationPass<GPUToSPIRVPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertGpuToSPIRV
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void GPUToSPIRVPass::runOnModule() {
+void GPUToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
 
   SmallVector<Operation *, 1> kernelModules;
   OpBuilder builder(context);
index 823860b..cbcfd74 100644 (file)
@@ -38,13 +38,13 @@ namespace {
 /// function and attaching binary data and entry point name as an attributes to
 /// created vulkan launch call op.
 class ConvertGpuLaunchFuncToVulkanLaunchFunc
-    : public ModulePass<ConvertGpuLaunchFuncToVulkanLaunchFunc> {
+    : public OperationPass<ConvertGpuLaunchFuncToVulkanLaunchFunc, ModuleOp> {
 public:
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertGpuLaunchFuncToVulkanLaunchFunc
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 
 private:
   /// Creates a SPIR-V binary shader from the given `module` using
@@ -68,14 +68,13 @@ private:
   /// operand is unsupported by Vulkan runtime.
   LogicalResult declareVulkanLaunchFunc(Location loc,
                                         gpu::LaunchFuncOp launchOp);
-
 };
 
 } // anonymous namespace
 
-void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
+void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() {
   bool done = false;
-  getModule().walk([this, &done](gpu::LaunchFuncOp op) {
+  getOperation().walk([this, &done](gpu::LaunchFuncOp op) {
     if (done) {
       op.emitError("should only contain one 'gpu::LaunchFuncOp' op");
       return signalPassFailure();
@@ -86,17 +85,17 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
 
   // Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
   for (auto gpuModule :
-       llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
+       llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
     gpuModule.erase();
 
   for (auto spirvModule :
-       llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
+       llvm::make_early_inc_range(getOperation().getOps<spirv::ModuleOp>()))
     spirvModule.erase();
 }
 
 LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
     Location loc, gpu::LaunchFuncOp launchOp) {
-  OpBuilder builder(getModule().getBody()->getTerminator());
+  OpBuilder builder(getOperation().getBody()->getTerminator());
   // TODO: Workgroup size is written into the kernel. So to properly modelling
   // vulkan launch, we cannot have the local workgroup size configuration here.
   SmallVector<Type, 8> vulkanLaunchTypes{launchOp.getOperandTypes()};
@@ -138,7 +137,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader(
 
 void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
     gpu::LaunchFuncOp launchOp) {
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
   OpBuilder builder(launchOp);
   Location loc = launchOp.getLoc();
 
index ebc8ded..2daa130 100644 (file)
@@ -58,7 +58,7 @@ namespace {
 /// * deinitVulkan         -- deinitializes vulkan runtime
 ///
 class VulkanLaunchFuncToVulkanCallsPass
-    : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
+    : public OperationPass<VulkanLaunchFuncToVulkanCallsPass, ModuleOp> {
 private:
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls
@@ -150,7 +150,7 @@ private:
   LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
 
 public:
-  void runOnModule() override;
+  void runOnOperation() override;
 
 private:
   LLVM::LLVMDialect *llvmDialect;
@@ -169,18 +169,18 @@ private:
 
 } // anonymous namespace
 
-void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
+void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
   initializeCachedTypes();
 
   // Collect SPIR-V attributes such as `spirv_blob` and
   // `spirv_entry_point_name`.
-  getModule().walk([this](LLVM::CallOp op) {
+  getOperation().walk([this](LLVM::CallOp op) {
     if (isVulkanLaunchCallOp(op))
       collectSPIRVAttributes(op);
   });
 
   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
-  getModule().walk([this](LLVM::CallOp op) {
+  getOperation().walk([this](LLVM::CallOp op) {
     if (isCInterfaceVulkanLaunchCallOp(op))
       translateVulkanLaunchCall(op);
   });
@@ -278,7 +278,7 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
 }
 
 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
   OpBuilder builder(module.getBody()->getTerminator());
 
   if (!module.lookupSymbol(kSetEntryPoint)) {
index 07c8111..99f106e 100644 (file)
@@ -561,17 +561,18 @@ void mlir::populateLinalgToLLVMConversionPatterns(
 }
 
 namespace {
-struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> {
+struct ConvertLinalgToLLVMPass
+    : public OperationPass<ConvertLinalgToLLVMPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertLinalgToLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void ConvertLinalgToLLVMPass::runOnModule() {
-  auto module = getModule();
+void ConvertLinalgToLLVMPass::runOnOperation() {
+  auto module = getOperation();
 
   // Convert to the LLVM IR dialect using the converter defined above.
   OwningRewritePatternList patterns;
index 0962746..4b66063 100644 (file)
@@ -16,18 +16,18 @@ using namespace mlir;
 
 namespace {
 /// A pass converting MLIR Linalg ops into SPIR-V ops.
-class LinalgToSPIRVPass : public ModulePass<LinalgToSPIRVPass> {
+class LinalgToSPIRVPass : public OperationPass<LinalgToSPIRVPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertLinalgToSPIRV
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void LinalgToSPIRVPass::runOnModule() {
+void LinalgToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
 
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =
index 1e127a0..ef5dabf 100644 (file)
@@ -2847,7 +2847,7 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
 
 namespace {
 /// A pass converting MLIR operations into the LLVM IR dialect.
-struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
+struct LLVMLoweringPass : public OperationPass<LLVMLoweringPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertStandardToLLVM
 #include "mlir/Conversion/Passes.h.inc"
@@ -2863,16 +2863,16 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
   LLVMLoweringPass(const LLVMLoweringPass &pass) {}
 
   /// Run the dialect converter on the module.
-  void runOnModule() override {
+  void runOnOperation() override {
     if (useBarePtrCallConv && emitCWrappers) {
-      getModule().emitError()
+      getOperation().emitError()
           << "incompatible conversion options: bare-pointer calling convention "
              "and C wrapper emission";
       signalPassFailure();
       return;
     }
 
-    ModuleOp m = getModule();
+    ModuleOp m = getOperation();
 
     LLVMTypeConverterCustomization customs;
     customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
index ab7dd85..86c8cd1 100644 (file)
@@ -22,18 +22,18 @@ using namespace mlir;
 namespace {
 /// A pass converting MLIR Standard operations into the SPIR-V dialect.
 class ConvertStandardToSPIRVPass
-    : public ModulePass<ConvertStandardToSPIRVPass> {
+    : public OperationPass<ConvertStandardToSPIRVPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertStandardToSPIRV
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void ConvertStandardToSPIRVPass::runOnModule() {
+void ConvertStandardToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
 
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =
index d5a4f86..b2a1c44 100644 (file)
@@ -1118,23 +1118,24 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns(
 }
 
 namespace {
-struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
+struct LowerVectorToLLVMPass
+    : public OperationPass<LowerVectorToLLVMPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertVectorToLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void LowerVectorToLLVMPass::runOnModule() {
+void LowerVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and
   // all contraction operations. Also applies folding and DCE.
   {
     OwningRewritePatternList patterns;
     populateVectorSlicesLoweringPatterns(patterns, &getContext());
     populateVectorContractLoweringPatterns(patterns, &getContext());
-    applyPatternsGreedily(getModule(), patterns);
+    applyPatternsGreedily(getOperation(), patterns);
   }
 
   // Convert to the LLVM IR dialect.
@@ -1148,8 +1149,8 @@ void LowerVectorToLLVMPass::runOnModule() {
   LLVMConversionTarget target(getContext());
   target.addDynamicallyLegalOp<FuncOp>(
       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
-  if (failed(
-          applyPartialConversion(getModule(), target, patterns, &converter))) {
+  if (failed(applyPartialConversion(getOperation(), target, patterns,
+                                    &converter))) {
     signalPassFailure();
   }
 }
index 2eadf87..daf9169 100644 (file)
@@ -214,16 +214,17 @@ namespace {
 /// The gpu.modules are intended to be compiled to a cubin blob independently in
 /// a separate pass. The external functions can then be annotated with the
 /// symbol of the cubin accessor function.
-class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
+class GpuKernelOutliningPass
+    : public OperationPass<GpuKernelOutliningPass, ModuleOp> {
 public:
 /// Include the generated pass utilities.
 #define GEN_PASS_GpuKernelOutlining
 #include "mlir/Dialect/GPU/Passes.h.inc"
 
-  void runOnModule() override {
-    SymbolTable symbolTable(getModule());
+  void runOnOperation() override {
+    SymbolTable symbolTable(getOperation());
     bool modified = false;
-    for (auto func : getModule().getOps<FuncOp>()) {
+    for (auto func : getOperation().getOps<FuncOp>()) {
       // Insert just after the function.
       Block::iterator insertPt(func.getOperation()->getNextNode());
       auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
@@ -255,8 +256,8 @@ public:
     // If any new module was inserted in this module, annotate this module as
     // a container module.
     if (modified)
-      getModule().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
-                          UnitAttr::get(&getContext()));
+      getOperation().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
+                             UnitAttr::get(&getContext()));
   }
 
 private:
@@ -267,7 +268,7 @@ private:
     // a SymbolTable by the caller. SymbolTable needs to be refactored to
     // prevent manual building of Ops with symbols in code using SymbolTables
     // and then this needs to use the OpBuilder.
-    auto context = getModule().getContext();
+    auto context = getOperation().getContext();
     Builder builder(context);
     OperationState state(kernelFunc.getLoc(),
                          gpu::GPUModuleOp::getOperationName());
index 79ed819..e462274 100644 (file)
@@ -80,14 +80,14 @@ static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns,
 
 namespace {
 class DecorateSPIRVCompositeTypeLayoutPass
-    : public ModulePass<DecorateSPIRVCompositeTypeLayoutPass> {
+    : public OperationPass<DecorateSPIRVCompositeTypeLayoutPass, ModuleOp> {
 private:
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() {
-  auto module = getModule();
+void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
+  auto module = getOperation();
   OwningRewritePatternList patterns;
   populateSPIRVLayoutInfoPatterns(patterns, module.getContext());
   ConversionTarget target(*(module.getContext()));
index b7832f5..2b519d6 100644 (file)
@@ -18,7 +18,7 @@
 using namespace mlir;
 
 namespace {
-struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
+struct PrintOpStatsPass : public OperationPass<PrintOpStatsPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_PrintOpStats
 #include "mlir/Transforms/Passes.h.inc"
@@ -26,7 +26,7 @@ struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
   explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {}
 
   // Prints the resultant operation statistics post iterating over the module.
-  void runOnModule() override;
+  void runOnOperation() override;
 
   // Print summary of op stats.
   void printSummary();
@@ -37,11 +37,11 @@ private:
 };
 } // namespace
 
-void PrintOpStatsPass::runOnModule() {
+void PrintOpStatsPass::runOnOperation() {
   opCount.clear();
 
   // Compute the operation statistics for each function in the module.
-  for (auto &op : getModule())
+  for (auto &op : getOperation())
     op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
   printSummary();
 }
index fcaff9a..c5d921d 100644 (file)
@@ -100,7 +100,7 @@ namespace {
 // PrintOpPass is simple pass to write graph per function.
 // Note: this is a module pass only to avoid interleaving on the same ostream
 // due to multi-threading over functions.
-struct PrintOpPass : public ModulePass<PrintOpPass> {
+struct PrintOpPass : public OperationPass<PrintOpPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_PrintOpGraph
 #include "mlir/Transforms/Passes.h.inc"
@@ -140,7 +140,7 @@ struct PrintOpPass : public ModulePass<PrintOpPass> {
     }
   }
 
-  void runOnModule() override { processModule(getModule()); }
+  void runOnOperation() override { processModule(getOperation()); }
 
 private:
   raw_ostream &os;
index e6cc52d..6ccfa04 100644 (file)
@@ -398,13 +398,13 @@ struct TestTypeConverter : public TypeConverter {
 };
 
 struct TestLegalizePatternDriver
-    : public ModulePass<TestLegalizePatternDriver> {
+    : public OperationPass<TestLegalizePatternDriver, ModuleOp> {
   /// The mode of conversion to use with the driver.
   enum class ConversionMode { Analysis, Full, Partial };
 
   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
 
-  void runOnModule() override {
+  void runOnOperation() override {
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
@@ -450,7 +450,8 @@ struct TestLegalizePatternDriver
 
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
-      (void)applyPartialConversion(getModule(), target, patterns, &converter);
+      (void)applyPartialConversion(getOperation(), target, patterns,
+                                   &converter);
       return;
     }
 
@@ -461,7 +462,7 @@ struct TestLegalizePatternDriver
         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
       });
 
-      (void)applyFullConversion(getModule(), target, patterns, &converter);
+      (void)applyFullConversion(getOperation(), target, patterns, &converter);
       return;
     }
 
@@ -470,7 +471,7 @@ struct TestLegalizePatternDriver
 
     // Analyze the convertible operations.
     DenseSet<Operation *> legalizedOps;
-    if (failed(applyAnalysisConversion(getModule(), target, patterns,
+    if (failed(applyAnalysisConversion(getOperation(), target, patterns,
                                        legalizedOps, &converter)))
       return signalPassFailure();
 
index 0e885c5..c1b9039 100644 (file)
@@ -13,9 +13,9 @@ using namespace mlir;
 
 namespace {
 /// This is a test pass for verifying FuncOp's eraseArgument method.
-struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
-  void runOnModule() override {
-    auto module = getModule();
+struct TestFuncEraseArg : public OperationPass<TestFuncEraseArg, ModuleOp> {
+  void runOnOperation() override {
+    auto module = getOperation();
 
     for (FuncOp func : module.getOps<FuncOp>()) {
       SmallVector<unsigned, 4> indicesToErase;
@@ -36,9 +36,9 @@ struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
 };
 
 /// This is a test pass for verifying FuncOp's setType method.
-struct TestFuncSetType : public ModulePass<TestFuncSetType> {
-  void runOnModule() override {
-    auto module = getModule();
+struct TestFuncSetType : public OperationPass<TestFuncSetType, ModuleOp> {
+  void runOnOperation() override {
+    auto module = getOperation();
     SymbolTable symbolTable(module);
 
     for (FuncOp func : module.getOps<FuncOp>()) {
index 9f52c42..a993485 100644 (file)
@@ -12,9 +12,9 @@
 using namespace mlir;
 
 namespace {
-struct SideEffectsPass : public ModulePass<SideEffectsPass> {
-  void runOnModule() override {
-    auto module = getModule();
+struct SideEffectsPass : public OperationPass<SideEffectsPass, ModuleOp> {
+  void runOnOperation() override {
+    auto module = getOperation();
 
     // Walk operations detecting side effects.
     SmallVector<MemoryEffects::EffectInstance, 8> effects;
index 6082cdc..c39615e 100644 (file)
@@ -15,7 +15,7 @@ using namespace mlir;
 namespace {
 /// This is a symbol test pass that tests the symbol uselist functionality
 /// provided by the symbol table along with erasing from the symbol table.
-struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
+struct SymbolUsesPass : public OperationPass<SymbolUsesPass, ModuleOp> {
   WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
                              SmallVectorImpl<FuncOp> &deadFunctions) {
     // Test computing uses on a non symboltable op.
@@ -59,8 +59,8 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
     return WalkResult::advance();
   }
 
-  void runOnModule() override {
-    auto module = getModule();
+  void runOnOperation() override {
+    auto module = getOperation();
 
     // Walk nested symbols.
     SmallVector<FuncOp, 4> deadFunctions;
@@ -86,9 +86,10 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
 
 /// This is a symbol test pass that tests the symbol use replacement
 /// functionality provided by the symbol table.
-struct SymbolReplacementPass : public ModulePass<SymbolReplacementPass> {
-  void runOnModule() override {
-    auto module = getModule();
+struct SymbolReplacementPass
+    : public OperationPass<SymbolReplacementPass, ModuleOp> {
+  void runOnOperation() override {
+    auto module = getOperation();
 
     // Walk nested functions and modules.
     module.getBodyRegion().walk([&](Operation *nestedOp) {
index 95bef9b..be8a747 100644 (file)
@@ -13,8 +13,8 @@
 using namespace mlir;
 
 namespace {
-struct TestModulePass : public ModulePass<TestModulePass> {
-  void runOnModule() final {}
+struct TestModulePass : public OperationPass<TestModulePass, ModuleOp> {
+  void runOnOperation() final {}
 };
 struct TestFunctionPass : public FunctionPass<TestFunctionPass> {
   void runOnFunction() final {}
index 508f708..6455dab 100644 (file)
@@ -18,11 +18,11 @@ using namespace mlir;
 
 namespace {
 struct TestAllReduceLoweringPass
-    : public ModulePass<TestAllReduceLoweringPass> {
-  void runOnModule() override {
+    : public OperationPass<TestAllReduceLoweringPass, ModuleOp> {
+  void runOnOperation() override {
     OwningRewritePatternList patterns;
     populateGpuRewritePatterns(&getContext(), patterns);
-    applyPatternsGreedily(getModule(), patterns);
+    applyPatternsGreedily(getOperation(), patterns);
   }
 };
 } // namespace
index 89c25da..a181d64 100644 (file)
@@ -17,9 +17,9 @@
 using namespace mlir;
 
 namespace {
-struct TestCallGraphPass : public ModulePass<TestCallGraphPass> {
-  void runOnModule() {
-    llvm::errs() << "Testing : " << getModule().getAttr("test.name") << "\n";
+struct TestCallGraphPass : public OperationPass<TestCallGraphPass, ModuleOp> {
+  void runOnOperation() override {
+    llvm::errs() << "Testing : " << getOperation().getAttr("test.name") << "\n";
     getAnalysis<CallGraph>().print(llvm::errs());
   }
 };
index baae529..47152c4 100644 (file)
@@ -17,7 +17,7 @@ namespace {
 /// It also takes all operations that are not function operations or
 /// terminators and clones them with opaque locations which store the initial
 /// locations.
-struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
+struct TestOpaqueLoc : public OperationPass<TestOpaqueLoc, ModuleOp> {
 
   /// A simple structure which is used for testing as an underlying location in
   /// OpaqueLoc.
@@ -29,11 +29,11 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
     int id;
   };
 
-  void runOnModule() override {
+  void runOnOperation() override {
     std::vector<std::unique_ptr<MyLocation>> myLocs;
     int last_it = 0;
 
-    getModule().walk([&](Operation *op) {
+    getOperation().walk([&](Operation *op) {
       myLocs.push_back(std::make_unique<MyLocation>(last_it++));
 
       Location loc = op->getLoc();
@@ -74,7 +74,7 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
       os.flush();
     });
 
-    getModule().walk([&](Operation *op) { op->emitOpError(); });
+    getOperation().walk([&](Operation *op) { op->emitOpError(); });
   }
 };