NFC: Remove `Module::getFunctions` in favor of a general `getOps<T>`.
authorRiver Riddle <riverriddle@google.com>
Tue, 9 Jul 2019 01:27:45 +0000 (18:27 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 9 Jul 2019 01:28:17 +0000 (18:28 -0700)
Modules can now contain more than just Functions, this just updates the iteration API to reflect that. The 'begin'/'end' methods have also been updated to iterate over opaque Operations.

PiperOrigin-RevId: 257099084

20 files changed:
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
mlir/include/mlir/IR/Module.h
mlir/lib/Analysis/OpStats.cpp
mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/GPU/Transforms/KernelOutlining.cpp
mlir/lib/IR/Module.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Pass/Pass.cpp
mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp
mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Transforms/DialectConversion.cpp

index a01a7fd79a39e901411d63c72c247a8670885b29..3cef8f3f50d4124239517f9f57f75230055f5367 100644 (file)
@@ -148,7 +148,7 @@ static void populateLinalg3ToLLVMConversionPatterns(
 
 void linalg::convertLinalg3ToLLVM(Module module) {
   // Remove affine constructs.
-  for (auto func : module) {
+  for (auto func : module.getOps<FuncOp>()) {
     auto rr = lowerAffineConstructs(func);
     (void)rr;
     assert(succeeded(rr) && "affine loop lowering failed");
index 317532f1fc395fcc5227b7b255afff34a97a13c6..fe24f0fcd3e21265ae7556dc3d7b535559af51a4 100644 (file)
@@ -140,7 +140,8 @@ public:
 
     // Delete any generic function left
     // FIXME: we may want this as a separate pass.
-    for (mlir::Function function : llvm::make_early_inc_range(module)) {
+    for (mlir::Function function :
+         llvm::make_early_inc_range(module.getOps<mlir::Function>())) {
       if (auto genericAttr =
               function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
         if (genericAttr.getValue())
index 2ea1c6cad9a0d822693510cd0e816135f12906ff..5267ae3d5db88e3c19f1dd07769795a393fce210 100644 (file)
@@ -369,7 +369,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
     // affine dialect: they already include conversion to the LLVM dialect.
 
     // First patch calls type to return memref instead of ToyArray
-    for (auto function : getModule()) {
+    for (auto function : getModule().getOps<FuncOp>()) {
       function.walk([&](Operation *op) {
         auto callOp = dyn_cast<CallOp>(op);
         if (!callOp)
@@ -384,7 +384,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
       });
     }
 
-    for (auto function : getModule()) {
+    for (auto function : getModule().getOps<FuncOp>()) {
       function.walk([&](Operation *op) {
         // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free).
         if (auto allocOp = dyn_cast<toy::AllocOp>(op)) {
@@ -403,7 +403,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
     }
 
     // Lower Linalg to affine
-    for (auto function : getModule())
+    for (auto function : getModule().getOps<FuncOp>())
       linalg::lowerToLoops(function);
 
     getModule().dump();
index 2adb4cd6f51e3ecf157d918f973d9dbee58d6a4c..a94b261b30eafab2448fc11d0f3403abc5bb3a27 100644 (file)
@@ -141,7 +141,8 @@ public:
 
     // Delete any generic function left
     // FIXME: we may want this as a separate pass.
-    for (mlir::Function function : llvm::make_early_inc_range(module)) {
+    for (mlir::Function function :
+         llvm::make_early_inc_range(module.getOps<mlir::Function>())) {
       if (auto genericAttr =
               function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
         if (genericAttr.getValue())
index 683a35fafee622ede135937a868c436c32b653ae..2ad1151def698c3e933fec81df17ba9251ca495c 100644 (file)
@@ -64,18 +64,17 @@ public:
   // Body Management.
   //===--------------------------------------------------------------------===//
 
-  // Iterate over the functions within the module.
-  using iterator = Block::op_iterator<FuncOp>;
-
-  // Iteration over the functions in the module.
-  iterator begin() { return getBody()->op_begin<FuncOp>(); }
-  iterator end() { return getBody()->op_end<FuncOp>(); }
-  Function front() { return *begin(); }
-  Function back() { return *std::prev(end()); }
-
-  /// This is the list of functions in the module.
-  llvm::iterator_range<iterator> getFunctions() {
-    return getBody()->getOps<FuncOp>();
+  /// Iteration over the operations in the module.
+  using iterator = Block::iterator;
+
+  iterator begin() { return getBody()->begin(); }
+  iterator end() { return getBody()->end(); }
+  Operation &front() { return *begin(); }
+
+  /// This returns a range of operations of the given type 'T' held within the
+  /// module.
+  template <typename T> llvm::iterator_range<Block::op_iterator<T>> getOps() {
+    return getBody()->getOps<T>();
   }
 
   /// Insert the operation into the back of the body, before the terminator.
@@ -83,7 +82,7 @@ public:
     insert(Block::iterator(getBody()->getTerminator()), op);
   }
 
-  /// Inser the operation at the given insertion point. Note: The operation is
+  /// Insert the operation at the given insertion point. Note: The operation is
   /// never inserted after the terminator, even if the insertion point is end().
   void insert(Operation *insertPt, Operation *op) {
     insert(Block::iterator(insertPt), op);
@@ -106,7 +105,7 @@ public:
   /// name exists. Function names never include the @ on them. Note: This
   /// performs a linear scan of held symbols.
   Function getNamedFunction(Identifier name) {
-    auto it = llvm::find_if(getFunctions(),
+    auto it = llvm::find_if(getOps<FuncOp>(),
                             [name](FuncOp fn) { return fn.getName() == name; });
     return it == end() ? nullptr : *it;
   }
index 75a2fc1a5dcc03f3cd4daf383309a85063a2e663..f01ec56ddb13a7ea398f1fa9afc899068363c0c9 100644 (file)
@@ -45,8 +45,8 @@ void PrintOpStatsPass::runOnModule() {
   opCount.clear();
 
   // Compute the operation statistics for each function in the module.
-  for (auto fn : getModule())
-    fn.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
+  for (auto &op : getModule())
+    op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
   printSummary();
 }
 
index da896afb090825f78123818533f7a4bb7d4a53a1..1dbedf9fcee9335f138ad3c6a534b6344cfe3f76 100644 (file)
@@ -64,7 +64,7 @@ public:
     LLVMInitializeNVPTXTargetMC();
     LLVMInitializeNVPTXAsmPrinter();
 
-    for (auto function : getModule()) {
+    for (auto function : getModule().getOps<FuncOp>()) {
       if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) {
         continue;
       }
index 0759324c77c4b0d91083550c4ac9da2a4bf05ecf..dafc5fa573037363857a6422440cf7ae809b17ce 100644 (file)
@@ -130,7 +130,7 @@ public:
     // Cache the used LLVM types.
     initializeCachedTypes();
 
-    for (auto func : getModule()) {
+    for (auto func : getModule().getOps<FuncOp>()) {
       func.walk<mlir::gpu::LaunchFuncOp>(
           [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
     }
index 6de304333899f86217dda3e3297b520283f8b288..6306c567907c38df330bd8b5ede5fbf810130be9 100644 (file)
@@ -123,7 +123,7 @@ public:
     auto module = getModule();
     Builder builder(&getContext());
 
-    auto functions = module.getFunctions();
+    auto functions = module.getOps<FuncOp>();
     for (auto it = functions.begin(); it != functions.end();) {
       // Move iterator to after the current function so that potential insertion
       // of the accessor is after the kernel with cubin iself.
index 2a52706c2770a854a324efeb9a416fea23a140b9..01d473e7f59235afda3f567928d906c6a7557925 100644 (file)
@@ -938,7 +938,7 @@ static void ensureDistinctSuccessors(Block &bb) {
 }
 
 void mlir::LLVM::ensureDistinctSuccessors(Module m) {
-  for (auto f : m) {
+  for (auto f : m.getOps<FuncOp>()) {
     for (auto &bb : f.getBlocks()) {
       ::ensureDistinctSuccessors(bb);
     }
index 4f110ac286aea7eb4e4108731504e73e3999246b..0bc7041bd6e8bbcb5488b693c45037ef38648343 100644 (file)
@@ -98,7 +98,7 @@ class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
 public:
   void runOnModule() override {
     ModuleManager moduleManager(getModule());
-    for (auto func : getModule()) {
+    for (auto func : getModule().getOps<FuncOp>()) {
       func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
         Function outlinedFunc = outlineKernelFunc(op);
         moduleManager.insert(outlinedFunc);
index c9955c2895d381a326077339026ac6a898f0c66f..4b9a1a6c9bc68c18c52626adc583d9b3581ee332 100644 (file)
@@ -107,7 +107,7 @@ LogicalResult ModuleOp::verify() {
 
   // Check that all functions are uniquely named.
   llvm::StringMap<Location> nameToOrigLoc;
-  for (auto fn : getFunctions()) {
+  for (auto fn : getOps<FuncOp>()) {
     auto it = nameToOrigLoc.try_emplace(fn.getName(), fn.getLoc());
     if (!it.second)
       return fn.emitError()
index 08df435a73d08900fa8f6c3a540a3e625eeff6a1..7da8f442a06c4037ac073ba89d9b789f4245705f 100644 (file)
@@ -22,7 +22,7 @@ using namespace mlir;
 
 /// Build a symbol table with the symbols within the given module.
 SymbolTable::SymbolTable(ModuleOp module) : context(module.getContext()) {
-  for (auto func : module) {
+  for (auto func : module.getOps<FuncOp>()) {
     auto inserted = symbolTable.insert({func.getName(), func});
     (void)inserted;
     assert(inserted.second &&
index a3d89c3c42bb5d98969d7fcaf2e326d4d6487274..2b9c893276a090437db7e672d09e5a268b6b8334 100644 (file)
@@ -805,7 +805,7 @@ static void lowerLinalgForToCFG(Function &f) {
 void LowerLinalgToLLVMPass::runOnModule() {
   auto module = getModule();
 
-  for (auto f : module.getFunctions()) {
+  for (auto f : module.getOps<FuncOp>()) {
     lowerLinalgSubViewOps(f);
     lowerLinalgForToCFG(f);
     if (failed(lowerAffineConstructs(f)))
index 2076dea88f9b72655947ec771947aac556953e3f..a35efbdcb7a2a16108bcfa2f8c5b6cc8755b4c12 100644 (file)
@@ -153,7 +153,7 @@ static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe,
 /// module.
 void ModuleToFunctionPassAdaptor::runOnModule() {
   ModuleAnalysisManager &mam = getAnalysisManager();
-  for (auto func : getModule()) {
+  for (auto func : getModule().getOps<FuncOp>()) {
     // Skip external functions.
     if (func.isExternal())
       continue;
@@ -185,7 +185,7 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() {
   // This ensures that an analysis manager exists for each function, as well as
   // providing a queue of functions to execute over.
   std::vector<std::pair<Function, FunctionAnalysisManager>> funcAMPairs;
-  for (auto func : getModule())
+  for (auto func : getModule().getOps<FuncOp>())
     if (!func.isExternal())
       funcAMPairs.emplace_back(func, mam.slice(func));
 
index 169fec3b39af14b13447f23e800d2100bb4c7c30..765a36e791a8fb65760854b23e52afb000d272d4 100644 (file)
@@ -129,7 +129,7 @@ void InferQuantizedTypesPass::runOnModule() {
 void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
                                             const TargetConfiguration &config) {
   CAGSlice cag(solverContext);
-  for (auto f : getModule()) {
+  for (auto f : getModule().getOps<FuncOp>()) {
     f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
   }
   config.finalizeAnchors(cag);
index e7ee2371b036705150cebb82b7a70f2e465bd7a1..78dbccf1728ee1360baa815fd385f20111b89c74 100644 (file)
@@ -45,7 +45,7 @@ LogicalResult serializeModule(Module module, StringRef outputFilename) {
   // wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed
   // to take in the SPIR-V ModuleOp directly after module and function are
   // migrated to be general ops.
-  for (auto fn : module) {
+  for (auto fn : module.getOps<FuncOp>()) {
     fn.walk<spirv::ModuleOp>([&](spirv::ModuleOp spirvModule) {
       if (done) {
         spirvModule.emitError("found more than one 'spv.module' op");
index b29db1a1ce7fc98893fd0309ba29fb8f58d76563..0fef76068124472d6b247f125f02db9485d5d75f 100644 (file)
@@ -68,7 +68,7 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module m) {
 
   // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
   // function as a kernel.
-  for (Function func : m) {
+  for (Function func : m.getOps<FuncOp>()) {
     if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
       continue;
 
index df2bd57d5f0fd5d5918458b7a67b502c63183422..a358e8363f44fc61931b4bb2cbf40a69d86e5695 100644 (file)
@@ -375,7 +375,7 @@ bool ModuleTranslation::convertOneFunction(Function func) {
 bool ModuleTranslation::convertFunctions() {
   // Declare all functions first because there may be function calls that form a
   // call graph with cycles.
-  for (Function function : mlirModule) {
+  for (Function function : mlirModule.getOps<FuncOp>()) {
     mlir::BoolAttr isVarArgsAttr =
         function.getAttrOfType<BoolAttr>("std.varargs");
     bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
@@ -392,7 +392,7 @@ bool ModuleTranslation::convertFunctions() {
   }
 
   // Convert functions.
-  for (Function function : mlirModule) {
+  for (Function function : mlirModule.getOps<FuncOp>()) {
     // Ignore external functions.
     if (function.isExternal())
       continue;
index 9375a7b8445ec067d1c90443b80d569f90b993ad..42683afc4684596ac4673fac1ea7e8ef0759a9bf 100644 (file)
@@ -1131,7 +1131,7 @@ LogicalResult
 mlir::applyConversionPatterns(Module module, ConversionTarget &target,
                               TypeConverter &converter,
                               OwningRewritePatternList &&patterns) {
-  SmallVector<Function, 32> allFunctions(module.getFunctions());
+  SmallVector<Function, 32> allFunctions(module.getOps<FuncOp>());
   return applyConversionPatterns(allFunctions, target, converter,
                                  std::move(patterns));
 }