Rename findFunction from the ML side of the house to be named getFunction(),
authorChris Lattner <clattner@google.com>
Thu, 27 Dec 2018 05:13:45 +0000 (21:13 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:38:49 +0000 (14:38 -0700)
making it more similar to the CFG side of things.  It is true that in a deeply
nested case that this is not a guaranteed O(1) time operation, and that 'get'
could lead compiler hackers to think this is cheap, but we need to merge these
and we can look into solutions for this in the future if it becomes a problem
in practice.

This is step 9/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 226983931

13 files changed:
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/Statement.h
mlir/include/mlir/IR/Statements.h
mlir/include/mlir/IR/StmtBlock.h
mlir/lib/Analysis/Utils.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinOps.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/SSAValue.cpp
mlir/lib/IR/Statement.cpp
mlir/lib/IR/StmtBlock.cpp
mlir/lib/Transforms/DmaGeneration.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp

index fe6e1fffff107088f2f499cd956d94e0e3efa101..532564defb23d36e82f7e64cab90eb3991adb0a2 100644 (file)
@@ -289,20 +289,20 @@ public:
   /// Create ML function builder and set insertion point to the given statement,
   /// which will cause subsequent insertions to go right before it.
   MLFuncBuilder(Statement *stmt)
-      // TODO: Eliminate findFunction from this.
-      : MLFuncBuilder(stmt->findFunction()) {
+      // TODO: Eliminate getFunction from this.
+      : MLFuncBuilder(stmt->getFunction()) {
     setInsertionPoint(stmt);
   }
 
   MLFuncBuilder(StmtBlock *block)
-      // TODO: Eliminate findFunction from this.
-      : MLFuncBuilder(block->findFunction()) {
+      // TODO: Eliminate getFunction from this.
+      : MLFuncBuilder(block->getFunction()) {
     setInsertionPoint(block, block->end());
   }
 
   MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint)
-      // TODO: Eliminate findFunction from this.
-      : MLFuncBuilder(block->findFunction()) {
+      // TODO: Eliminate getFunction from this.
+      : MLFuncBuilder(block->getFunction()) {
     setInsertionPoint(block, insertPoint);
   }
 
index d6576d32486ee4763e7ded53c21ebc22c3e736a8..188002a646ca2414344e738286634834497942a3 100644 (file)
@@ -104,7 +104,7 @@ public:
   /// Returns the function that this statement is part of.
   /// The function is determined by traversing the chain of parent statements.
   /// Returns nullptr if the statement is unlinked.
-  MLFunction *findFunction() const;
+  MLFunction *getFunction() const;
 
   /// Destroys this statement and its subclass data.
   void destroy();
index 0e14e0ebf275ddf115833af9188da5a111b8183c..d653f59a26196739815b76eb8dadfaed0e18188a 100644 (file)
@@ -290,7 +290,7 @@ public:
   }
 
   /// Resolve base class ambiguity.
-  using Statement::findFunction;
+  using Statement::getFunction;
 
   /// Operand iterators.
   using operand_iterator = OperandIterator<ForStmt, MLValue>;
index feb1aa6665bd97d4b8d1b848cd0ed15b7c647ccc..fe4d5f417cce08bad4ea36d73d1a4d5a30253556 100644 (file)
@@ -62,11 +62,11 @@ public:
     return const_cast<StmtBlock *>(this)->getContainingStmt();
   }
 
-  /// Returns the function that this statement block is part of.
-  /// The function is determined by traversing the chain of parent statements.
-  MLFunction *findFunction();
-  const MLFunction *findFunction() const {
-    return const_cast<StmtBlock *>(this)->findFunction();
+  /// Returns the function that this statement block is part of.  The function
+  /// is determined by traversing the chain of parent statements.
+  MLFunction *getFunction();
+  const MLFunction *getFunction() const {
+    return const_cast<StmtBlock *>(this)->getFunction();
   }
 
   //===--------------------------------------------------------------------===//
index 2428265acdb8a2f4438e003264a59af3d574977a..0c6cfea7ccd31d6f8f455e6dfd92f999732d4d0c 100644 (file)
@@ -39,7 +39,7 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) {
   if (&a == &b)
     return false;
 
-  if (a.findFunction() != b.findFunction())
+  if (a.getFunction() != b.getFunction())
     return false;
 
   if (a.getBlock() == b.getBlock()) {
index ec610e8cea7b8e971bb1c0111364c1e51d24d8fd..fa04b1a3a85fb57b1617dca2b7cd13526797ae2c 100644 (file)
@@ -1015,7 +1015,7 @@ protected:
       case SSAValueKind::BlockArgument:
         // If this is an argument to the function, give it an 'arg' name.
         if (auto *block = cast<BlockArgument>(value)->getOwner())
-          if (auto *fn = block->findFunction())
+          if (auto *fn = block->getFunction())
             if (&fn->getBlockList().front() == block) {
               specialName << "arg" << nextArgumentID++;
               break;
@@ -1639,7 +1639,7 @@ void BasicBlock::printAsOperand(raw_ostream &os, bool printType) {
 }
 
 void Statement::print(raw_ostream &os) const {
-  MLFunction *function = findFunction();
+  MLFunction *function = getFunction();
   if (!function) {
     os << "<<UNLINKED STATEMENT>>\n";
     return;
@@ -1653,7 +1653,7 @@ void Statement::print(raw_ostream &os) const {
 void Statement::dump() const { print(llvm::errs()); }
 
 void StmtBlock::printBlock(raw_ostream &os) const {
-  const MLFunction *function = findFunction();
+  const MLFunction *function = getFunction();
   ModuleState state(function->getContext());
   ModulePrinter modulePrinter(os, state);
   MLFunctionPrinter(function, modulePrinter).print(this);
index dfd59c4d3800ee3a3f30a3a4546190a4a5ebe6f4..04d032dc2ebcdb938eab053cdc9d033279ee376b 100644 (file)
@@ -474,7 +474,7 @@ void ReturnOp::print(OpAsmPrinter *p) const {
 bool ReturnOp::verify() const {
   const Function *function;
   if (auto *stmt = dyn_cast<OperationStmt>(getOperation()))
-    function = stmt->getBlock()->findFunction();
+    function = stmt->getFunction();
   else
     function = cast<Instruction>(getOperation())->getFunction();
 
index c946a76a98bea4f6da03358c54a4617a3cbeb4ca..d50e7070f7089eab0ae9558a1c6806ed0a7db76b 100644 (file)
@@ -99,7 +99,7 @@ void Operation::setLoc(Location loc) {
 Function *Operation::getOperationFunction() {
   if (auto *inst = llvm::dyn_cast<Instruction>(this))
     return inst->getFunction();
-  return llvm::cast<OperationStmt>(this)->findFunction();
+  return llvm::cast<OperationStmt>(this)->getFunction();
 }
 
 /// Return the number of operands this operation has.
index 375e1057b01593a2d5843c7b5be4ed114a5f069b..32365d67f349f128cd3c21f4612330adb536fdc1 100644 (file)
@@ -57,9 +57,9 @@ Function *SSAValue::getFunction() {
   case SSAValueKind::BlockArgument:
     return cast<BlockArgument>(this)->getFunction();
   case SSAValueKind::StmtResult:
-    return getDefiningStmt()->findFunction();
+    return getDefiningStmt()->getFunction();
   case SSAValueKind::ForStmt:
-    return cast<ForStmt>(this)->findFunction();
+    return cast<ForStmt>(this)->getFunction();
   }
 }
 
@@ -121,6 +121,6 @@ MLFunction *MLValue::getFunction() {
 /// Return the function that this argument is defined in.
 MLFunction *BlockArgument::getFunction() {
   if (auto *owner = getOwner())
-    return owner->findFunction();
+    return owner->getFunction();
   return nullptr;
 }
index f072cc145614f7be16ba686bf9772ba69bf354b8..8922aaf72e0866971b7e4d59324574ddfbddc217 100644 (file)
@@ -81,8 +81,8 @@ Statement *Statement::getParentStmt() const {
   return block ? block->getContainingStmt() : nullptr;
 }
 
-MLFunction *Statement::findFunction() const {
-  return block ? block->findFunction() : nullptr;
+MLFunction *Statement::getFunction() const {
+  return block ? block->getFunction() : nullptr;
 }
 
 MLValue *Statement::getOperand(unsigned idx) {
@@ -368,7 +368,7 @@ MLIRContext *OperationStmt::getContext() const {
 
   // In the very odd case where we have no operands or results, fall back to
   // doing a find.
-  return findFunction()->getContext();
+  return getFunction()->getContext();
 }
 
 bool OperationStmt::isReturn() const { return isa<ReturnOp>(); }
@@ -560,7 +560,7 @@ MLIRContext *IfStmt::getContext() const {
   // Check for degenerate case of if statement with no operands.
   // This is unlikely, but legal.
   if (operands.empty())
-    return findFunction()->getContext();
+    return getFunction()->getContext();
 
   return getOperand(0)->getType().getContext();
 }
index 0716d63c6ef8ba4e78828d18f61ac62a858de3f6..1c2c77d2da3b61e992a50a328a47ac60c82f9bc3 100644 (file)
@@ -32,7 +32,7 @@ Statement *StmtBlock::getContainingStmt() {
   return parent ? parent->getContainingStmt() : nullptr;
 }
 
-MLFunction *StmtBlock::findFunction() {
+MLFunction *StmtBlock::getFunction() {
   StmtBlock *block = this;
   while (auto *stmt = block->getContainingStmt()) {
     block = stmt->getBlock();
index a927516345a08a7be099ee25453f6cbed926e3d6..62cf55e37d9064e40d295fce751ab1f287308555 100644 (file)
@@ -180,7 +180,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
   MLFuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
 
   // Builder to create constants at the top level.
-  MLFuncBuilder top(forStmt->findFunction());
+  MLFuncBuilder top(forStmt->getFunction());
 
   auto loc = forStmt->getLoc();
   auto *memref = region.memref;
index 023d3ebc6437c7616a6596b332c9bf2eb5de6c46..5a5617f3fb1a9dd4db75dd8117e5fe115e4e657d 100644 (file)
@@ -101,7 +101,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
   // Replaces all IV uses to its single iteration value.
   if (!forStmt->use_empty()) {
     if (forStmt->hasConstantLowerBound()) {
-      auto *mlFunc = forStmt->findFunction();
+      auto *mlFunc = forStmt->getFunction();
       MLFuncBuilder topBuilder(&mlFunc->getBody()->front());
       auto constOp = topBuilder.create<ConstantIndexOp>(
           forStmt->getLoc(), forStmt->getConstantLowerBound());