Eliminate the Instruction, BasicBlock, CFGFunction, MLFunction, and ExtFunction class...
authorChris Lattner <clattner@google.com>
Thu, 27 Dec 2018 19:07:34 +0000 (11:07 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:39:49 +0000 (14:39 -0700)
This *only* changes the internal data structures, it does not affect the user visible syntax or structure of MLIR code.  Function gets new "isCFG()" sorts of predicates as a transitional measure.

This patch is gross in a number of ways, largely in an effort to reduce the amount of mechanical churn in one go.  It introduces a bunch of using decls to keep the old names alive for now, and a bunch of stuff needs to be renamed.

This is step 10/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 227044402

47 files changed:
mlir/include/mlir/Analysis/Dominance.h
mlir/include/mlir/IR/BasicBlock.h [deleted file]
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/BuiltinOps.h
mlir/include/mlir/IR/CFGFunction.h [deleted file]
mlir/include/mlir/IR/CFGValue.h [deleted file]
mlir/include/mlir/IR/Function.h
mlir/include/mlir/IR/FunctionGraphTraits.h
mlir/include/mlir/IR/Instructions.h [deleted file]
mlir/include/mlir/IR/MLFunction.h [deleted file]
mlir/include/mlir/IR/MLValue.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/SSAValue.h
mlir/include/mlir/IR/Statement.h
mlir/include/mlir/IR/Statements.h
mlir/include/mlir/IR/StmtBlock.h
mlir/include/mlir/IR/StmtVisitor.h
mlir/include/mlir/Pass.h
mlir/include/mlir/Transforms/CFGFunctionViewGraph.h
mlir/include/mlir/Transforms/LoopUtils.h
mlir/include/mlir/Transforms/Utils.h
mlir/lib/Analysis/Dominance.cpp
mlir/lib/Analysis/OpStats.cpp
mlir/lib/Analysis/Pass.cpp
mlir/lib/Analysis/Verifier.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BasicBlock.cpp [deleted file]
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinOps.cpp
mlir/lib/IR/Function.cpp
mlir/lib/IR/Instructions.cpp [deleted file]
mlir/lib/IR/Operation.cpp
mlir/lib/IR/SSAValue.cpp
mlir/lib/IR/Statement.cpp
mlir/lib/IR/StmtBlock.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
mlir/lib/Transforms/CSE.cpp
mlir/lib/Transforms/ConstantFold.cpp
mlir/lib/Transforms/ConvertToCFG.cpp
mlir/lib/Transforms/LowerAffineApply.cpp
mlir/lib/Transforms/SimplifyAffineExpr.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/lib/Transforms/Utils/Utils.cpp
mlir/tools/mlir-opt/mlir-opt.cpp

index 22cd7daf6249525983be7134fa95b84821972753..4ec61869d2f824b33c948259cbf05f8fb0c434d9 100644 (file)
@@ -62,7 +62,7 @@ public:
 
   /// Return true if instruction A dominates instruction B.
   bool dominates(const SSAValue *a, const Instruction *b) {
-    return a->getDefiningInst() == b || properlyDominates(a, b);
+    return (Instruction *)a->getDefiningInst() == b || properlyDominates(a, b);
   }
 
   // dominates/properlyDominates for basic blocks.
diff --git a/mlir/include/mlir/IR/BasicBlock.h b/mlir/include/mlir/IR/BasicBlock.h
deleted file mode 100644 (file)
index 1b38a5e..0000000
+++ /dev/null
@@ -1,381 +0,0 @@
-//===- BasicBlock.h - MLIR BasicBlock Class ---------------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef MLIR_IR_BASICBLOCK_H
-#define MLIR_IR_BASICBLOCK_H
-
-#include "mlir/IR/Instructions.h"
-
-namespace mlir {
-class BBArgument;
-class CFGFunction;
-template <typename BlockType> class PredecessorIterator;
-template <typename BlockType> class SuccessorIterator;
-
-/// Each basic block in a CFG function contains a list of basic block arguments,
-/// normal instructions, and a terminator instruction.
-///
-/// Basic blocks form a graph (the CFG) which can be traversed through
-/// predecessor and successor edges.
-class BasicBlock
-    : public IRObjectWithUseList,
-      public llvm::ilist_node_with_parent<BasicBlock, CFGFunction> {
-public:
-  explicit BasicBlock();
-  ~BasicBlock();
-
-  /// Return the function that a BasicBlock is part of.
-  CFGFunction *getFunction() { return function; }
-  const CFGFunction *getFunction() const { return function; }
-
-  /// Return the function that a BasicBlock is part of.
-  const CFGFunction *getParent() const { return function; }
-  CFGFunction *getParent() { return function; }
-
-  //===--------------------------------------------------------------------===//
-  // Block argument management
-  //===--------------------------------------------------------------------===//
-
-  // This is the list of arguments to the block.
-  using BBArgListType = ArrayRef<BBArgument *>;
-  BBArgListType getArguments() const { return arguments; }
-
-  using args_iterator = BBArgListType::iterator;
-  using reverse_args_iterator = BBArgListType::reverse_iterator;
-  args_iterator args_begin() const { return getArguments().begin(); }
-  args_iterator args_end() const { return getArguments().end(); }
-  reverse_args_iterator args_rbegin() const { return getArguments().rbegin(); }
-  reverse_args_iterator args_rend() const { return getArguments().rend(); }
-
-  bool args_empty() const { return arguments.empty(); }
-
-  /// Add one value to the argument list.
-  BBArgument *addArgument(Type type);
-
-  /// Add one argument to the argument list for each type specified in the list.
-  llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
-
-  /// Erase the argument at 'index' and remove it from the argument list.
-  void eraseArgument(unsigned index);
-
-  unsigned getNumArguments() const { return arguments.size(); }
-  BBArgument *getArgument(unsigned i) { return arguments[i]; }
-  const BBArgument *getArgument(unsigned i) const { return arguments[i]; }
-
-  //===--------------------------------------------------------------------===//
-  // Operation list management
-  //===--------------------------------------------------------------------===//
-
-  /// This is the list of operations in the block.
-  using OperationListType = llvm::iplist<Instruction>;
-  OperationListType &getOperations() { return operations; }
-  const OperationListType &getOperations() const { return operations; }
-
-  // Iteration over the operations in the block.
-  using iterator = OperationListType::iterator;
-  using const_iterator = OperationListType::const_iterator;
-  using reverse_iterator = OperationListType::reverse_iterator;
-  using const_reverse_iterator = OperationListType::const_reverse_iterator;
-
-  iterator begin() { return operations.begin(); }
-  iterator end() { return operations.end(); }
-  const_iterator begin() const { return operations.begin(); }
-  const_iterator end() const { return operations.end(); }
-  reverse_iterator rbegin() { return operations.rbegin(); }
-  reverse_iterator rend() { return operations.rend(); }
-  const_reverse_iterator rbegin() const { return operations.rbegin(); }
-  const_reverse_iterator rend() const { return operations.rend(); }
-
-  bool empty() const { return operations.empty(); }
-  void push_back(Instruction *inst) { operations.push_back(inst); }
-  void push_front(Instruction *inst) { operations.push_front(inst); }
-
-  Instruction &back() { return operations.back(); }
-  const Instruction &back() const {
-    return const_cast<BasicBlock *>(this)->back();
-  }
-
-  Instruction &front() { return operations.front(); }
-  const Instruction &front() const {
-    return const_cast<BasicBlock*>(this)->front();
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Terminator management
-  //===--------------------------------------------------------------------===//
-
-  /// Get the terminator instruction of this block, or null if the block is
-  /// malformed.
-  Instruction *getTerminator() const;
-
-  //===--------------------------------------------------------------------===//
-  // Predecessors and successors.
-  //===--------------------------------------------------------------------===//
-
-  // Predecessor iteration.
-  using const_pred_iterator = PredecessorIterator<const BasicBlock>;
-  const_pred_iterator pred_begin() const;
-  const_pred_iterator pred_end() const;
-  llvm::iterator_range<const_pred_iterator> getPredecessors() const;
-
-  using pred_iterator = PredecessorIterator<BasicBlock>;
-  pred_iterator pred_begin();
-  pred_iterator pred_end();
-  llvm::iterator_range<pred_iterator> getPredecessors();
-
-  /// Return true if this block has no predecessors.
-  bool hasNoPredecessors() const;
-
-  /// If this basic block has exactly one predecessor, return it.  Otherwise,
-  /// return null.
-  ///
-  /// Note that if a block has duplicate predecessors from a single block (e.g.
-  /// if you have a conditional branch with the same block as the true/false
-  /// destinations) is not considered to be a single predecessor.
-  BasicBlock *getSinglePredecessor();
-
-  const BasicBlock *getSinglePredecessor() const {
-    return const_cast<BasicBlock *>(this)->getSinglePredecessor();
-  }
-
-  // Indexed successor access.
-  unsigned getNumSuccessors() const {
-    return getTerminator()->getNumSuccessors();
-  }
-  const BasicBlock *getSuccessor(unsigned i) const {
-    return const_cast<BasicBlock *>(this)->getSuccessor(i);
-  }
-  BasicBlock *getSuccessor(unsigned i) {
-    return getTerminator()->getSuccessor(i);
-  }
-
-  // Successor iteration.
-  using const_succ_iterator = SuccessorIterator<const BasicBlock>;
-  const_succ_iterator succ_begin() const;
-  const_succ_iterator succ_end() const;
-  llvm::iterator_range<const_succ_iterator> getSuccessors() const;
-
-  using succ_iterator = SuccessorIterator<BasicBlock>;
-  succ_iterator succ_begin();
-  succ_iterator succ_end();
-  llvm::iterator_range<succ_iterator> getSuccessors();
-
-  //===--------------------------------------------------------------------===//
-  // Manipulators
-  //===--------------------------------------------------------------------===//
-
-  /// Unlink this BasicBlock from its CFGFunction and delete it.
-  void eraseFromFunction();
-
-  /// Split the basic block into two basic blocks before the specified
-  /// instruction or iterator.
-  ///
-  /// Note that all instructions BEFORE the specified iterator stay as part of
-  /// the original basic block, an unconditional branch is added to the original
-  /// block (going to the new block), and the rest of the instructions in the
-  /// original block are moved to the new BB, including the old terminator.  The
-  /// newly formed BasicBlock is returned.
-  ///
-  /// This function invalidates the specified iterator.
-  BasicBlock *splitBasicBlock(iterator splitBefore);
-  BasicBlock *splitBasicBlock(Instruction *splitBeforeInst) {
-    return splitBasicBlock(iterator(splitBeforeInst));
-  }
-
-  void print(raw_ostream &os) const;
-  void dump() const;
-
-  /// Print out the name of the basic block without printing its body.
-  /// NOTE: The printType argument is ignored.  We keep it for compatibility
-  /// with LLVM dominator machinery that expects it to exist.
-  void printAsOperand(raw_ostream &os, bool printType = true);
-
-  /// getSublistAccess() - Returns pointer to member of operation list
-  static OperationListType BasicBlock::*getSublistAccess(Instruction *) {
-    return &BasicBlock::operations;
-  }
-
-private:
-  CFGFunction *function = nullptr;
-
-  /// This is the list of operations in the block.
-  OperationListType operations;
-
-  /// This is the list of arguments to the block.
-  std::vector<BBArgument *> arguments;
-
-  BasicBlock(const BasicBlock&) = delete;
-  void operator=(const BasicBlock&) = delete;
-
-  friend struct llvm::ilist_traits<BasicBlock>;
-};
-
-//===----------------------------------------------------------------------===//
-// Predecessors
-//===----------------------------------------------------------------------===//
-
-/// Implement a predecessor iterator as a forward iterator.  This works by
-/// walking the use lists of basic blocks.  The entries on this list are the
-/// BasicBlockOperands that are embedded into terminator instructions.  From the
-/// operand, we can get the terminator that contains it, and it's parent block
-/// is the predecessor.
-template <typename BlockType>
-class PredecessorIterator
-    : public llvm::iterator_facade_base<PredecessorIterator<BlockType>,
-                                        std::forward_iterator_tag,
-                                        BlockType *> {
-public:
-  PredecessorIterator(BasicBlockOperand *firstOperand)
-      : bbUseIterator(firstOperand) {}
-
-  PredecessorIterator &operator=(const PredecessorIterator &rhs) {
-    bbUseIterator = rhs.bbUseIterator;
-  }
-
-  bool operator==(const PredecessorIterator &rhs) const {
-    return bbUseIterator == rhs.bbUseIterator;
-  }
-
-  BlockType *operator*() const {
-    // The use iterator points to an operand of a terminator.  The predecessor
-    // we return is the basic block that that terminator is embedded into.
-    return bbUseIterator.getUser()->getBlock();
-  }
-
-  PredecessorIterator &operator++() {
-    ++bbUseIterator;
-    return *this;
-  }
-
-  /// Get the successor number in the predecessor terminator.
-  unsigned getSuccessorIndex() const {
-    return bbUseIterator->getOperandNumber();
-  }
-
-private:
-  using BBUseIterator = SSAValueUseIterator<BasicBlockOperand, Instruction>;
-  BBUseIterator bbUseIterator;
-};
-
-inline auto BasicBlock::pred_begin() const -> const_pred_iterator {
-  return const_pred_iterator((BasicBlockOperand *)getFirstUse());
-}
-
-inline auto BasicBlock::pred_end() const -> const_pred_iterator {
-  return const_pred_iterator(nullptr);
-}
-
-inline auto BasicBlock::getPredecessors() const
-    -> llvm::iterator_range<const_pred_iterator> {
-  return {pred_begin(), pred_end()};
-}
-
-inline auto BasicBlock::pred_begin() -> pred_iterator {
-  return pred_iterator((BasicBlockOperand *)getFirstUse());
-}
-
-inline auto BasicBlock::pred_end() -> pred_iterator {
-  return pred_iterator(nullptr);
-}
-
-inline auto BasicBlock::getPredecessors()
-    -> llvm::iterator_range<pred_iterator> {
-  return {pred_begin(), pred_end()};
-}
-
-//===----------------------------------------------------------------------===//
-// Successors
-//===----------------------------------------------------------------------===//
-
-/// This template implments the successor iterators for basic block.
-template <typename BlockType>
-class SuccessorIterator final
-    : public IndexedAccessorIterator<SuccessorIterator<BlockType>, BlockType,
-                                     BlockType> {
-public:
-  /// Initializes the result iterator to the specified index.
-  SuccessorIterator(BlockType *object, unsigned index)
-      : IndexedAccessorIterator<SuccessorIterator<BlockType>, BlockType,
-                                BlockType>(object, index) {}
-
-  SuccessorIterator(const SuccessorIterator &other)
-      : SuccessorIterator(other.object, other.index) {}
-
-  /// Support converting to the const variant. This will be a no-op for const
-  /// variant.
-  operator SuccessorIterator<const BlockType>() const {
-    return SuccessorIterator<const BlockType>(this->object, this->index);
-  }
-
-  BlockType *operator*() const {
-    return this->object->getSuccessor(this->index);
-  }
-
-  /// Get the successor number in the terminator.
-  unsigned getSuccessorIndex() const { return this->index; }
-};
-
-inline auto BasicBlock::succ_begin() const -> const_succ_iterator {
-  return const_succ_iterator(this, 0);
-}
-
-inline auto BasicBlock::succ_end() const -> const_succ_iterator {
-  return const_succ_iterator(this, getNumSuccessors());
-}
-
-inline auto BasicBlock::getSuccessors() const
-    -> llvm::iterator_range<const_succ_iterator> {
-  return {succ_begin(), succ_end()};
-}
-
-inline auto BasicBlock::succ_begin() -> succ_iterator {
-  return succ_iterator(this, 0);
-}
-
-inline auto BasicBlock::succ_end() -> succ_iterator {
-  return succ_iterator(this, getNumSuccessors());
-}
-
-inline auto BasicBlock::getSuccessors() -> llvm::iterator_range<succ_iterator> {
-  return {succ_begin(), succ_end()};
-}
-
-} // end namespace mlir
-
-//===----------------------------------------------------------------------===//
-// ilist_traits for BasicBlock
-//===----------------------------------------------------------------------===//
-
-namespace llvm {
-
-template <>
-struct ilist_traits<::mlir::BasicBlock>
-  : public ilist_alloc_traits<::mlir::BasicBlock> {
-  using BasicBlock = ::mlir::BasicBlock;
-  using block_iterator = simple_ilist<BasicBlock>::iterator;
-
-  void addNodeToList(BasicBlock *block);
-  void removeNodeFromList(BasicBlock *block);
-  void transferNodesFromList(ilist_traits<BasicBlock> &otherList,
-                             block_iterator first, block_iterator last);
-private:
-  mlir::CFGFunction *getContainingFunction();
-};
-} // end namespace llvm
-
-
-#endif  // MLIR_IR_BASICBLOCK_H
index 532564defb23d36e82f7e64cab90eb3991adb0a2..a632078903a601fce24b311b6319ef20df758b44 100644 (file)
@@ -18,8 +18,7 @@
 #ifndef MLIR_IR_BUILDERS_H
 #define MLIR_IR_BUILDERS_H
 
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/Statements.h"
 
 namespace mlir {
@@ -222,7 +221,7 @@ public:
   }
 
   void insert(Instruction *opInst) {
-    block->getOperations().insert(insertPoint, opInst);
+    block->getStatements().insert(insertPoint, opInst);
   }
 
   /// Add new basic block and set the insertion point to the end of it.  If an
@@ -232,7 +231,7 @@ public:
   BasicBlock *createBlock(BasicBlock *insertBefore = nullptr);
 
   /// Create an operation given the fields represented as an OperationState.
-  Instruction *createOperation(const OperationState &state);
+  OperationStmt *createOperation(const OperationState &state);
 
   /// Create operation of specific op type at the current insertion point
   /// without verifying to see if it is valid.
@@ -268,8 +267,8 @@ public:
     return OpPointer<OpTy>();
   }
 
-  Instruction *cloneOperation(const Instruction &srcOpInst) {
-    auto *op = srcOpInst.clone();
+  OperationStmt *cloneOperation(const OperationStmt &srcOpInst) {
+    auto *op = cast<OperationStmt>(srcOpInst.clone(getContext()));
     insert(op);
     return op;
   }
@@ -439,11 +438,11 @@ public:
       : Builder(mlFuncBuilder.getContext()), builder(mlFuncBuilder),
         kind(Function::Kind::MLFunc) {}
   FuncBuilder(Operation *op) : Builder(op->getContext()) {
-    if (auto *inst = dyn_cast<Instruction>(op)) {
-      builder = builderUnion(inst);
+    if (op->getOperationFunction()->isCFG()) {
+      builder = builderUnion(CFGFuncBuilder(cast<OperationInst>(op)));
       kind = Function::Kind::CFGFunc;
     } else {
-      builder = builderUnion(cast<OperationStmt>(op));
+      builder = builderUnion(MLFuncBuilder(cast<OperationStmt>(op)));
       kind = Function::Kind::MLFunc;
     }
   }
@@ -479,7 +478,7 @@ public:
   /// OperationStmt when building a ML function.
   void setInsertionPoint(Operation *op) {
     if (kind == Function::Kind::CFGFunc)
-      builder.cfg.setInsertionPoint(cast<Instruction>(op));
+      builder.cfg.setInsertionPoint(cast<OperationStmt>(op));
     else
       builder.ml.setInsertionPoint(cast<OperationStmt>(op));
   }
@@ -489,8 +488,6 @@ private:
   union builderUnion {
     builderUnion(CFGFuncBuilder cfg) : cfg(cfg) {}
     builderUnion(MLFuncBuilder ml) : ml(ml) {}
-    builderUnion(Instruction *op) : cfg(op) {}
-    builderUnion(OperationStmt *op) : ml(op) {}
     // Default initializer to allow deferring initialization of member.
     builderUnion() {}
 
index dd06846728e5744d0f9bd731c4963c8a3a2b0bbb..ec88e2d157be63d2af7aefc45aff408048a53567 100644 (file)
@@ -29,7 +29,6 @@
 #include "mlir/IR/OpDefinition.h"
 
 namespace mlir {
-class BasicBlock;
 class Builder;
 class MLValue;
 
diff --git a/mlir/include/mlir/IR/CFGFunction.h b/mlir/include/mlir/IR/CFGFunction.h
deleted file mode 100644 (file)
index b203fc8..0000000
+++ /dev/null
@@ -1,99 +0,0 @@
-//===- CFGFunction.h - MLIR CFGFunction Class -------------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef MLIR_IR_CFGFUNCTION_H
-#define MLIR_IR_CFGFUNCTION_H
-
-#include "mlir/IR/BasicBlock.h"
-#include "mlir/IR/Function.h"
-
-namespace mlir {
-
-// This kind of function is defined in terms of a "Control Flow Graph" of basic
-// blocks, each of which includes instructions.
-class CFGFunction : public Function {
-public:
-  CFGFunction(Location location, StringRef name, FunctionType type,
-              ArrayRef<NamedAttribute> attrs = {});
-
-  ~CFGFunction();
-
-  //===--------------------------------------------------------------------===//
-  // BasicBlock list management
-  //===--------------------------------------------------------------------===//
-
-  /// This is the list of blocks in the function.
-  using BasicBlockListType = llvm::iplist<BasicBlock>;
-  BasicBlockListType &getBlocks() { return blocks; }
-  const BasicBlockListType &getBlocks() const { return blocks; }
-
-  // Iteration over the block in the function.
-  using iterator = BasicBlockListType::iterator;
-  using const_iterator = BasicBlockListType::const_iterator;
-  using reverse_iterator = BasicBlockListType::reverse_iterator;
-  using const_reverse_iterator = BasicBlockListType::const_reverse_iterator;
-
-  iterator begin() { return blocks.begin(); }
-  iterator end() { return blocks.end(); }
-  const_iterator begin() const { return blocks.begin(); }
-  const_iterator end() const { return blocks.end(); }
-  reverse_iterator rbegin() { return blocks.rbegin(); }
-  reverse_iterator rend() { return blocks.rend(); }
-  const_reverse_iterator rbegin() const { return blocks.rbegin(); }
-  const_reverse_iterator rend() const { return blocks.rend(); }
-
-  bool empty() const { return blocks.empty(); }
-  void push_back(BasicBlock *block) { blocks.push_back(block); }
-  void push_front(BasicBlock *block) { blocks.push_front(block); }
-
-  BasicBlock &back() { return blocks.back(); }
-  const BasicBlock &back() const {
-    return const_cast<CFGFunction *>(this)->back();
-  }
-
-  BasicBlock &front() { return blocks.front(); }
-  const BasicBlock &front() const {
-    return const_cast<CFGFunction*>(this)->front();
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Other
-  //===--------------------------------------------------------------------===//
-
-  /// getSublistAccess() - Returns pointer to member of block list
-  static BasicBlockListType CFGFunction::*getSublistAccess(BasicBlock*) {
-    return &CFGFunction::blocks;
-  }
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Function *func) {
-    return func->getKind() == Kind::CFGFunc;
-  }
-
-  /// Displays the CFG in a window. This is for use from the debugger and
-  /// depends on Graphviz to generate the graph.
-  /// This function is defined in CFGFunctionViewGraph and only works with that
-  /// target linked.
-  void viewGraph() const;
-
-private:
-  BasicBlockListType blocks;
-};
-
-} // end namespace mlir
-
-#endif  // MLIR_IR_CFGFUNCTION_H
diff --git a/mlir/include/mlir/IR/CFGValue.h b/mlir/include/mlir/IR/CFGValue.h
deleted file mode 100644 (file)
index 4fefe97..0000000
+++ /dev/null
@@ -1,124 +0,0 @@
-//===- CFGValue.h - CFGValue base class and SSA type decls ------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines SSA manipulation implementations for CFG functions.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_CFGVALUE_H
-#define MLIR_IR_CFGVALUE_H
-
-#include "mlir/IR/SSAValue.h"
-
-namespace mlir {
-class BasicBlock;
-class CFGValue;
-class CFGFunction;
-class Instruction;
-
-/// This enum contains all of the SSA value kinds that are valid in a CFG
-/// function.  This should be kept as a proper subtype of SSAValueKind,
-/// including having all of the values of the enumerators align.
-enum class CFGValueKind {
-  BBArgument = (int)SSAValueKind::BBArgument,
-  InstResult = (int)SSAValueKind::InstResult,
-};
-
-/// The operand of a CFG Instruction contains a CFGValue.
-using InstOperand = IROperandImpl<CFGValue, Instruction>;
-
-/// CFGValue is the base class for SSA values in CFG functions.
-class CFGValue : public SSAValueImpl<InstOperand, Instruction, CFGValueKind> {
-public:
-  static bool classof(const SSAValue *value) {
-    switch (value->getKind()) {
-    case SSAValueKind::BBArgument:
-    case SSAValueKind::InstResult:
-      return true;
-
-    case SSAValueKind::BlockArgument:
-    case SSAValueKind::StmtResult:
-    case SSAValueKind::ForStmt:
-      return false;
-    }
-  }
-
-  /// Return the function that this CFGValue is defined in.
-  CFGFunction *getFunction();
-
-  /// Return the function that this CFGValue is defined in.
-  const CFGFunction *getFunction() const {
-    return const_cast<CFGValue *>(this)->getFunction();
-  }
-
-protected:
-  CFGValue(CFGValueKind kind, Type type) : SSAValueImpl(kind, type) {}
-};
-
-/// Basic block arguments are CFG Values.
-class BBArgument : public CFGValue {
-public:
-  static bool classof(const SSAValue *value) {
-    return value->getKind() == SSAValueKind::BBArgument;
-  }
-
-  /// Return the function that this argument is defined in.
-  CFGFunction *getFunction();
-  const CFGFunction *getFunction() const {
-    return const_cast<BBArgument *>(this)->getFunction();
-  }
-
-  BasicBlock *getOwner() { return owner; }
-  const BasicBlock *getOwner() const { return owner; }
-
-private:
-  friend class BasicBlock; // For access to private constructor.
-  BBArgument(Type type, BasicBlock *owner)
-      : CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
-
-  /// The owner of this operand.
-  /// TODO: can encode this more efficiently to avoid the space hit of this
-  /// through bitpacking shenanigans.
-  BasicBlock *const owner;
-};
-
-/// Instruction results are CFG Values.
-class InstResult : public CFGValue {
-public:
-  InstResult(Type type, Instruction *owner)
-      : CFGValue(CFGValueKind::InstResult, type), owner(owner) {}
-
-  static bool classof(const SSAValue *value) {
-    return value->getKind() == SSAValueKind::InstResult;
-  }
-
-  Instruction *getOwner() { return owner; }
-  const Instruction *getOwner() const { return owner; }
-
-  /// Return the number of this result.
-  unsigned getResultNumber() const;
-
-private:
-  /// The owner of this operand.
-  /// TODO: can encode this more efficiently to avoid the space hit of this
-  /// through bitpacking shenanigans.
-  Instruction *const owner;
-};
-
-} // namespace mlir
-
-#endif
index 4dd0a4641ef306f28f5b1f6c266ce9c5f8d464c5..9f2e202857f8d2a83d6390f841054c2b2d3c6fda 100644 (file)
@@ -27,6 +27,9 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/MLValue.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StmtBlock.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ilist.h"
@@ -36,6 +39,7 @@ class AttributeListStorage;
 class FunctionType;
 class MLIRContext;
 class Module;
+template <typename ObjectType, typename ElementType> class ArgumentIterator;
 
 /// NamedAttribute is used for function attribute lists, it holds an
 /// identifier for the name and a value for the attribute.  The attribute
@@ -47,8 +51,15 @@ class Function : public llvm::ilist_node_with_parent<Function, Module> {
 public:
   enum class Kind { ExtFunc, CFGFunc, MLFunc };
 
+  Function(Kind kind, Location location, StringRef name, FunctionType type,
+           ArrayRef<NamedAttribute> attrs = {});
+  ~Function();
+
   Kind getKind() const { return (Kind)nameAndKind.getInt(); }
 
+  bool isCFG() const { return getKind() == Kind::CFGFunc; }
+  bool isML() const { return getKind() == Kind::MLFunc; }
+
   /// The source location the operation was defined or derived from.
   Location getLoc() const { return location; }
 
@@ -65,11 +76,103 @@ public:
   Module *getModule() { return module; }
   const Module *getModule() const { return module; }
 
-  /// Unlink this instruction from its module and delete it.
-  void eraseFromModule();
+  /// Unlink this function from its module and delete it.
+  void erase();
+
+  //===--------------------------------------------------------------------===//
+  // Body Handling
+  //===--------------------------------------------------------------------===//
+
+  StmtBlockList &getBlockList() { return blocks; }
+  const StmtBlockList &getBlockList() const { return blocks; }
+
+  /// This is the list of blocks in the function.
+  using BlockListType = llvm::iplist<BasicBlock>;
+  BlockListType &getBlocks() { return blocks.getBlocks(); }
+  const BlockListType &getBlocks() const { return blocks.getBlocks(); }
+
+  // Iteration over the block in the function.
+  using iterator = BlockListType::iterator;
+  using const_iterator = BlockListType::const_iterator;
+  using reverse_iterator = BlockListType::reverse_iterator;
+  using const_reverse_iterator = BlockListType::const_reverse_iterator;
+
+  iterator begin() { return blocks.begin(); }
+  iterator end() { return blocks.end(); }
+  const_iterator begin() const { return blocks.begin(); }
+  const_iterator end() const { return blocks.end(); }
+  reverse_iterator rbegin() { return blocks.rbegin(); }
+  reverse_iterator rend() { return blocks.rend(); }
+  const_reverse_iterator rbegin() const { return blocks.rbegin(); }
+  const_reverse_iterator rend() const { return blocks.rend(); }
+
+  bool empty() const { return blocks.empty(); }
+  void push_back(BasicBlock *block) { blocks.push_back(block); }
+  void push_front(BasicBlock *block) { blocks.push_front(block); }
+
+  BasicBlock &back() { return blocks.back(); }
+  const BasicBlock &back() const {
+    return const_cast<CFGFunction *>(this)->back();
+  }
+
+  BasicBlock &front() { return blocks.front(); }
+  const BasicBlock &front() const {
+    return const_cast<CFGFunction *>(this)->front();
+  }
+
+  /// Return the 'return' statement of this MLFunction.
+  const OperationStmt *getReturnStmt() const;
+  OperationStmt *getReturnStmt();
 
-  /// Delete this object.
-  void destroy();
+  // These should only be used on MLFunctions.
+  StmtBlock *getBody() {
+    assert(isML());
+    return &blocks.front();
+  }
+  const StmtBlock *getBody() const {
+    return const_cast<Function *>(this)->getBody();
+  }
+
+  /// Walk the statements in the function in preorder, calling the callback for
+  /// each Operation statement.
+  void walk(std::function<void(OperationStmt *)> callback);
+
+  /// Walk the statements in the function in postorder, calling the callback for
+  /// each Operation statement.
+  void walkPostOrder(std::function<void(OperationStmt *)> callback);
+
+  //===--------------------------------------------------------------------===//
+  // Arguments
+  //===--------------------------------------------------------------------===//
+
+  /// Returns number of arguments.
+  unsigned getNumArguments() const { return getType().getInputs().size(); }
+
+  /// Gets argument.
+  BlockArgument *getArgument(unsigned idx) {
+    return getBlocks().front().getArgument(idx);
+  }
+
+  const BlockArgument *getArgument(unsigned idx) const {
+    return getBlocks().front().getArgument(idx);
+  }
+
+  // Supports non-const operand iteration.
+  using args_iterator = ArgumentIterator<MLFunction, BlockArgument>;
+  args_iterator args_begin();
+  args_iterator args_end();
+  llvm::iterator_range<args_iterator> getArguments();
+
+  // Supports const operand iteration.
+  using const_args_iterator =
+      ArgumentIterator<const MLFunction, const BlockArgument>;
+  const_args_iterator args_begin() const;
+  const_args_iterator args_end() const;
+  llvm::iterator_range<const_args_iterator> getArguments() const;
+
+  //===--------------------------------------------------------------------===//
+  // Other
+  //===--------------------------------------------------------------------===//
 
   /// Perform (potentially expensive) checks of invariants, used to detect
   /// compiler bugs.  On error, this reports the error through the MLIRContext
@@ -93,10 +196,11 @@ public:
   /// handlers that may be listening.
   void emitNote(const Twine &message) const;
 
-protected:
-  Function(Kind kind, Location location, StringRef name, FunctionType type,
-           ArrayRef<NamedAttribute> attrs = {});
-  ~Function();
+  /// Displays the CFG in a window. This is for use from the debugger and
+  /// depends on Graphviz to generate the graph.
+  /// This function is defined in CFGFunctionViewGraph and only works with that
+  /// target linked.
+  void viewGraph() const;
 
 private:
   /// The name of the function and the kind of function this is.
@@ -114,23 +218,70 @@ private:
   /// This holds general named attributes for the function.
   AttributeListStorage *attrs;
 
+  /// The contents of the body.
+  StmtBlockList blocks;
+
   void operator=(const Function &) = delete;
   friend struct llvm::ilist_traits<Function>;
 };
 
-/// An extfunc declaration is a declaration of a function signature that is
-/// defined in some other module.
-class ExtFunction : public Function {
+//===--------------------------------------------------------------------===//
+// ArgumentIterator
+//===--------------------------------------------------------------------===//
+
+/// This template implements the argument iterator in terms of getArgument(idx).
+template <typename ObjectType, typename ElementType>
+class ArgumentIterator final
+    : public IndexedAccessorIterator<ArgumentIterator<ObjectType, ElementType>,
+                                     ObjectType, ElementType> {
 public:
-  ExtFunction(Location location, StringRef name, FunctionType type,
-              ArrayRef<NamedAttribute> attrs = {});
+  /// Initializes the result iterator to the specified index.
+  ArgumentIterator(ObjectType *object, unsigned index)
+      : IndexedAccessorIterator<ArgumentIterator<ObjectType, ElementType>,
+                                ObjectType, ElementType>(object, index) {}
+
+  /// Support converting to the const variant. This will be a no-op for const
+  /// variant.
+  operator ArgumentIterator<const ObjectType, const ElementType>() const {
+    return ArgumentIterator<const ObjectType, const ElementType>(this->object,
+                                                                 this->index);
+  }
 
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Function *func) {
-    return func->getKind() == Kind::ExtFunc;
+  ElementType *operator*() const {
+    return this->object->getArgument(this->index);
   }
 };
 
+//===--------------------------------------------------------------------===//
+// MLFunction iterator methods.
+//===--------------------------------------------------------------------===//
+
+inline MLFunction::args_iterator MLFunction::args_begin() {
+  return args_iterator(this, 0);
+}
+
+inline MLFunction::args_iterator MLFunction::args_end() {
+  return args_iterator(this, getNumArguments());
+}
+
+inline llvm::iterator_range<MLFunction::args_iterator>
+MLFunction::getArguments() {
+  return {args_begin(), args_end()};
+}
+
+inline MLFunction::const_args_iterator MLFunction::args_begin() const {
+  return const_args_iterator(this, 0);
+}
+
+inline MLFunction::const_args_iterator MLFunction::args_end() const {
+  return const_args_iterator(this, getNumArguments());
+}
+
+inline llvm::iterator_range<MLFunction::const_args_iterator>
+MLFunction::getArguments() const {
+  return {args_begin(), args_end()};
+}
+
 } // end namespace mlir
 
 //===----------------------------------------------------------------------===//
@@ -145,7 +296,7 @@ struct ilist_traits<::mlir::Function>
   using Function = ::mlir::Function;
   using function_iterator = simple_ilist<Function>::iterator;
 
-  static void deleteNode(Function *function) { function->destroy(); }
+  static void deleteNode(Function *function) { delete function; }
 
   void addNodeToList(Function *function);
   void removeNodeFromList(Function *function);
index 9bcf4fde442bdee318f03cb2e7bff957ee02cf33..95b20476880f4b068e65bb1e95e388213cb040ce 100644 (file)
@@ -24,8 +24,7 @@
 #ifndef MLIR_IR_CFGFUNCTIONGRAPHTRAITS_H
 #define MLIR_IR_CFGFUNCTIONGRAPHTRAITS_H
 
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
 #include "llvm/ADT/GraphTraits.h"
 
 namespace llvm {
@@ -154,135 +153,71 @@ struct GraphTraits<Inverse<const mlir::CFGFunction *>>
   }
 };
 
-template <> struct GraphTraits<mlir::StmtBlock *> {
-  using ChildIteratorType = mlir::StmtBlock::succ_iterator;
-  using Node = mlir::StmtBlock;
-  using NodeRef = Node *;
-
-  static NodeRef getEntryNode(NodeRef bb) { return bb; }
-
-  static ChildIteratorType child_begin(NodeRef node) {
-    return node->succ_begin();
-  }
-  static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); }
-};
-
-template <> struct GraphTraits<const mlir::StmtBlock *> {
-  using ChildIteratorType = mlir::StmtBlock::const_succ_iterator;
-  using Node = const mlir::StmtBlock;
-  using NodeRef = Node *;
-
-  static NodeRef getEntryNode(NodeRef bb) { return bb; }
-
-  static ChildIteratorType child_begin(NodeRef node) {
-    return node->succ_begin();
-  }
-  static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); }
-};
-
-template <> struct GraphTraits<Inverse<mlir::StmtBlock *>> {
-  using ChildIteratorType = mlir::StmtBlock::pred_iterator;
-  using Node = mlir::StmtBlock;
-  using NodeRef = Node *;
-  static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) {
-    return inverseGraph.Graph;
-  }
-  static inline ChildIteratorType child_begin(NodeRef node) {
-    return node->pred_begin();
-  }
-  static inline ChildIteratorType child_end(NodeRef node) {
-    return node->pred_end();
-  }
-};
-
-template <> struct GraphTraits<Inverse<const mlir::StmtBlock *>> {
-  using ChildIteratorType = mlir::StmtBlock::const_pred_iterator;
-  using Node = const mlir::StmtBlock;
-  using NodeRef = Node *;
-
-  static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) {
-    return inverseGraph.Graph;
-  }
-  static inline ChildIteratorType child_begin(NodeRef node) {
-    return node->pred_begin();
-  }
-  static inline ChildIteratorType child_end(NodeRef node) {
-    return node->pred_end();
-  }
-};
-
 template <>
-struct GraphTraits<mlir::MLFunction *> : public GraphTraits<mlir::StmtBlock *> {
-  using GraphType = mlir::MLFunction *;
-  using NodeRef = mlir::StmtBlock *;
+struct GraphTraits<mlir::StmtBlockList *>
+    : public GraphTraits<mlir::BasicBlock *> {
+  using GraphType = mlir::StmtBlockList *;
+  using NodeRef = mlir::BasicBlock *;
 
-  static NodeRef getEntryNode(GraphType fn) {
-    return &fn->getBlockList().front();
-  }
+  static NodeRef getEntryNode(GraphType fn) { return &fn->front(); }
 
-  using nodes_iterator = pointer_iterator<mlir::StmtBlockList::iterator>;
+  using nodes_iterator = pointer_iterator<mlir::CFGFunction::iterator>;
   static nodes_iterator nodes_begin(GraphType fn) {
-    return nodes_iterator(fn->getBlockList().begin());
+    return nodes_iterator(fn->begin());
   }
   static nodes_iterator nodes_end(GraphType fn) {
-    return nodes_iterator(fn->getBlockList().end());
+    return nodes_iterator(fn->end());
   }
 };
 
 template <>
-struct GraphTraits<const mlir::MLFunction *>
-    : public GraphTraits<const mlir::StmtBlock *> {
-  using GraphType = const mlir::MLFunction *;
-  using NodeRef = const mlir::StmtBlock *;
+struct GraphTraits<const mlir::StmtBlockList *>
+    : public GraphTraits<const mlir::BasicBlock *> {
+  using GraphType = const mlir::StmtBlockList *;
+  using NodeRef = const mlir::BasicBlock *;
 
-  static NodeRef getEntryNode(GraphType fn) {
-    return &fn->getBlockList().front();
-  }
+  static NodeRef getEntryNode(GraphType fn) { return &fn->front(); }
 
-  using nodes_iterator = pointer_iterator<mlir::StmtBlockList::const_iterator>;
+  using nodes_iterator = pointer_iterator<mlir::CFGFunction::const_iterator>;
   static nodes_iterator nodes_begin(GraphType fn) {
-    return nodes_iterator(fn->getBlockList().begin());
+    return nodes_iterator(fn->begin());
   }
   static nodes_iterator nodes_end(GraphType fn) {
-    return nodes_iterator(fn->getBlockList().end());
+    return nodes_iterator(fn->end());
   }
 };
 
 template <>
-struct GraphTraits<Inverse<mlir::MLFunction *>>
-    : public GraphTraits<Inverse<mlir::StmtBlock *>> {
-  using GraphType = Inverse<mlir::MLFunction *>;
+struct GraphTraits<Inverse<mlir::StmtBlockList *>>
+    : public GraphTraits<Inverse<mlir::BasicBlock *>> {
+  using GraphType = Inverse<mlir::StmtBlockList *>;
   using NodeRef = NodeRef;
 
-  static NodeRef getEntryNode(GraphType fn) {
-    return &fn.Graph->getBlockList().front();
-  }
+  static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); }
 
-  using nodes_iterator = pointer_iterator<mlir::StmtBlockList::iterator>;
+  using nodes_iterator = pointer_iterator<mlir::CFGFunction::iterator>;
   static nodes_iterator nodes_begin(GraphType fn) {
-    return nodes_iterator(fn.Graph->getBlockList().begin());
+    return nodes_iterator(fn.Graph->begin());
   }
   static nodes_iterator nodes_end(GraphType fn) {
-    return nodes_iterator(fn.Graph->getBlockList().end());
+    return nodes_iterator(fn.Graph->end());
   }
 };
 
 template <>
-struct GraphTraits<Inverse<const mlir::MLFunction *>>
-    : public GraphTraits<Inverse<const mlir::StmtBlock *>> {
-  using GraphType = Inverse<const mlir::MLFunction *>;
+struct GraphTraits<Inverse<const mlir::StmtBlockList *>>
+    : public GraphTraits<Inverse<const mlir::BasicBlock *>> {
+  using GraphType = Inverse<const mlir::StmtBlockList *>;
   using NodeRef = NodeRef;
 
-  static NodeRef getEntryNode(GraphType fn) {
-    return &fn.Graph->getBlockList().front();
-  }
+  static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); }
 
-  using nodes_iterator = pointer_iterator<mlir::StmtBlockList::const_iterator>;
+  using nodes_iterator = pointer_iterator<mlir::CFGFunction::const_iterator>;
   static nodes_iterator nodes_begin(GraphType fn) {
-    return nodes_iterator(fn.Graph->getBlockList().begin());
+    return nodes_iterator(fn.Graph->begin());
   }
   static nodes_iterator nodes_end(GraphType fn) {
-    return nodes_iterator(fn.Graph->getBlockList().end());
+    return nodes_iterator(fn.Graph->end());
   }
 };
 
diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h
deleted file mode 100644 (file)
index c63a6e7..0000000
+++ /dev/null
@@ -1,379 +0,0 @@
-//===- Instructions.h - MLIR CFG Instruction Classes ------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines the classes for CFGFunction instructions.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_INSTRUCTIONS_H
-#define MLIR_IR_INSTRUCTIONS_H
-
-#include "mlir/IR/CFGValue.h"
-#include "mlir/IR/Identifier.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/ilist.h"
-#include "llvm/Support/TrailingObjects.h"
-
-namespace mlir {
-class BasicBlock;
-class CFGFunction;
-} // end namespace mlir
-
-//===----------------------------------------------------------------------===//
-// ilist_traits for Instruction
-//===----------------------------------------------------------------------===//
-
-namespace llvm {
-
-template <> struct ilist_traits<::mlir::Instruction> {
-  using Instruction = ::mlir::Instruction;
-  using instr_iterator = simple_ilist<Instruction>::iterator;
-
-  static void deleteNode(Instruction *inst);
-  void addNodeToList(Instruction *inst);
-  void removeNodeFromList(Instruction *inst);
-  void transferNodesFromList(ilist_traits<Instruction> &otherList,
-                             instr_iterator first, instr_iterator last);
-
-private:
-  mlir::BasicBlock *getContainingBlock();
-};
-
-} // end namespace llvm
-
-namespace mlir {
-
-/// The operand of a Terminator contains a BasicBlock.
-using BasicBlockOperand = IROperandImpl<BasicBlock, Instruction>;
-
-// The trailing objects of an instruction are layed out as follows:
-//   - InstResult        : The results of the instruction.
-//   - InstOperand       : The operands of the instruction. For terminators, the
-//                         operands for all successors are concatenated here.
-//   - BasicBlockOperand : Use-list of successor blocks if this is a terminator.
-//   - unsigned          : Count of operands held for each of the successors.
-//
-// Note: For Terminators, we rely on the assumption that all non-successor
-// operands are placed at the beginning of the operands list.
-class Instruction final
-    : public Operation,
-      public IROperandOwner,
-      public llvm::ilist_node_with_parent<Instruction, BasicBlock>,
-      private llvm::TrailingObjects<Instruction, InstResult, InstOperand,
-                                    BasicBlockOperand, unsigned> {
-public:
-  using IROperandOwner::getContext;
-  using IROperandOwner::getLoc;
-  using IROperandOwner::setLoc;
-
-  //===--------------------------------------------------------------------===//
-  // Operands
-  //===--------------------------------------------------------------------===//
-
-  unsigned getNumOperands() const { return numOperands; }
-
-  CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
-  const CFGValue *getOperand(unsigned idx) const {
-    return getInstOperand(idx).get();
-  }
-  void setOperand(unsigned idx, CFGValue *value) {
-    getInstOperand(idx).set(value);
-  }
-
-  // Support non-const operand iteration.
-  using operand_iterator = OperandIterator<Instruction, CFGValue>;
-
-  operand_iterator operand_begin() { return operand_iterator(this, 0); }
-
-  operand_iterator operand_end() {
-    return operand_iterator(this, getNumOperands());
-  }
-
-  llvm::iterator_range<operand_iterator> getOperands() {
-    return {operand_begin(), operand_end()};
-  }
-
-  // Support const operand iteration.
-  using const_operand_iterator =
-      OperandIterator<const Instruction, const CFGValue>;
-
-  const_operand_iterator operand_begin() const {
-    return const_operand_iterator(this, 0);
-  }
-
-  const_operand_iterator operand_end() const {
-    return const_operand_iterator(this, getNumOperands());
-  }
-
-  llvm::iterator_range<const_operand_iterator> getOperands() const {
-    return {operand_begin(), operand_end()};
-  }
-
-  ArrayRef<InstOperand> getInstOperands() const {
-    return {getTrailingObjects<InstOperand>(), numOperands};
-  }
-
-  MutableArrayRef<InstOperand> getInstOperands() {
-    return {getTrailingObjects<InstOperand>(), numOperands};
-  }
-
-  // Accessors to InstOperand.
-  InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }
-  const InstOperand &getInstOperand(unsigned idx) const {
-    return getInstOperands()[idx];
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Results
-  //===--------------------------------------------------------------------===//
-
-  unsigned getNumResults() const { return numResults; }
-
-  CFGValue *getResult(unsigned idx) { return &getInstResult(idx); }
-  const CFGValue *getResult(unsigned idx) const { return &getInstResult(idx); }
-
-  // Support non-const result iteration.
-  using result_iterator = ResultIterator<Instruction, CFGValue>;
-  result_iterator result_begin() { return result_iterator(this, 0); }
-  result_iterator result_end() {
-    return result_iterator(this, getNumResults());
-  }
-  llvm::iterator_range<result_iterator> getResults() {
-    return {result_begin(), result_end()};
-  }
-
-  // Support const result iteration.
-  using const_result_iterator =
-      ResultIterator<const Instruction, const CFGValue>;
-  const_result_iterator result_begin() const {
-    return const_result_iterator(this, 0);
-  }
-
-  const_result_iterator result_end() const {
-    return const_result_iterator(this, getNumResults());
-  }
-
-  llvm::iterator_range<const_result_iterator> getResults() const {
-    return {result_begin(), result_end()};
-  }
-
-  ArrayRef<InstResult> getInstResults() const {
-    return {getTrailingObjects<InstResult>(), numResults};
-  }
-
-  MutableArrayRef<InstResult> getInstResults() {
-    return {getTrailingObjects<InstResult>(), numResults};
-  }
-
-  InstResult &getInstResult(unsigned idx) { return getInstResults()[idx]; }
-
-  const InstResult &getInstResult(unsigned idx) const {
-    return getInstResults()[idx];
-  }
-
-  // Support result type iteration.
-  using result_type_iterator =
-      ResultTypeIterator<const Instruction, const CFGValue>;
-  result_type_iterator result_type_begin() const {
-    return result_type_iterator(this, 0);
-  }
-
-  result_type_iterator result_type_end() const {
-    return result_type_iterator(this, getNumResults());
-  }
-
-  llvm::iterator_range<result_type_iterator> getResultTypes() const {
-    return {result_type_begin(), result_type_end()};
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Terminators
-  //===--------------------------------------------------------------------===//
-
-  MutableArrayRef<BasicBlockOperand> getBasicBlockOperands() {
-    assert(isTerminator() && "Only terminators have a block operands list.");
-    return {getTrailingObjects<BasicBlockOperand>(), numSuccs};
-  }
-  ArrayRef<BasicBlockOperand> getBasicBlockOperands() const {
-    return const_cast<Instruction *>(this)->getBasicBlockOperands();
-  }
-
-  MutableArrayRef<InstOperand> getSuccessorInstOperands(unsigned index) {
-    assert(isTerminator() && "Only terminators have successors.");
-    assert(index < getNumSuccessors());
-    unsigned succOpIndex = getSuccessorOperandIndex(index);
-    auto *operandBegin = getInstOperands().data() + succOpIndex;
-    return {operandBegin, getNumSuccessorOperands(index)};
-  }
-  ArrayRef<InstOperand> getSuccessorInstOperands(unsigned index) const {
-    return const_cast<Instruction *>(this)->getSuccessorInstOperands(index);
-  }
-
-  unsigned getNumSuccessors() const { return getBasicBlockOperands().size(); }
-  unsigned getNumSuccessorOperands(unsigned index) const {
-    assert(isTerminator() && "Only terminators have successors.");
-    assert(index < getNumSuccessors());
-    return getTrailingObjects<unsigned>()[index];
-  }
-
-  BasicBlock *getSuccessor(unsigned index) {
-    assert(index < getNumSuccessors());
-    return getBasicBlockOperands()[index].get();
-  }
-  const BasicBlock *getSuccessor(unsigned index) const {
-    return const_cast<Instruction *>(this)->getSuccessor(index);
-  }
-  void setSuccessor(BasicBlock *block, unsigned index);
-
-  /// Erase a specific operand from the operand list of the successor at
-  /// 'index'.
-  void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
-    assert(succIndex < getNumSuccessors());
-    assert(opIndex < getNumSuccessorOperands(succIndex));
-    eraseOperand(getSuccessorOperandIndex(succIndex) + opIndex);
-    --getTrailingObjects<unsigned>()[succIndex];
-  }
-
-  /// Get the index of the first operand of the successor at the provided
-  /// index.
-  unsigned getSuccessorOperandIndex(unsigned index) const {
-    assert(isTerminator() && "Only terminators have successors.");
-    assert(index < getNumSuccessors());
-
-    // Count the number of operands for each of the successors after, and
-    // including, the one at 'index'. This is based upon the assumption that all
-    // non successor operands are placed at the beginning of the operand list.
-    auto *successorOpCountBegin = getTrailingObjects<unsigned>();
-    unsigned postSuccessorOpCount =
-        std::accumulate(successorOpCountBegin + index,
-                        successorOpCountBegin + getNumSuccessors(), 0);
-    return getNumOperands() - postSuccessorOpCount;
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Other
-  //===--------------------------------------------------------------------===//
-
-  /// Create a new Instruction with the specified fields.
-  static Instruction *
-  create(Location location, OperationName name, ArrayRef<CFGValue *> operands,
-         ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes,
-         ArrayRef<BasicBlock *> successors, MLIRContext *context);
-
-  Instruction *clone() const;
-
-  /// Return the BasicBlock containing this instruction.
-  const BasicBlock *getBlock() const { return block; }
-  BasicBlock *getBlock() { return block; }
-
-  /// Return the CFGFunction containing this instruction.
-  CFGFunction *getFunction();
-  const CFGFunction *getFunction() const {
-    return const_cast<Instruction *>(this)->getFunction();
-  }
-
-  void print(raw_ostream &os) const;
-  void dump() const;
-
-  /// Unlink this instruction from its BasicBlock and delete it.
-  void erase();
-
-  /// Destroy this instruction and its subclass data.
-  void destroy();
-
-  /// Unlink this instruction from its current basic block and insert
-  /// it right before `existingInst` which may be in the same or another block
-  /// of the same function.
-  void moveBefore(Instruction *existingInst);
-
-  /// Unlink this instruction from its current basic block and insert
-  /// it right before `iterator` in the specified basic block, which must be in
-  /// the same function.
-  void moveBefore(BasicBlock *block,
-                  llvm::iplist<Instruction>::iterator iterator);
-
-  /// This drops all operand uses from this instruction, which is an essential
-  /// step in breaking cyclic dependences between references when they are to
-  /// be deleted.
-  void dropAllReferences();
-
-  /// Emit an error about fatal conditions with this operation, reporting up to
-  /// any diagnostic handlers that may be listening.  This function always
-  /// returns true.  NOTE: This may terminate the containing application, only
-  /// use when the IR is in an inconsistent state.
-  bool emitError(const Twine &message) const;
-
-  /// Emit a warning about this operation, reporting up to any diagnostic
-  /// handlers that may be listening.
-  void emitWarning(const Twine &message) const;
-
-  /// Emit a note about this operation, reporting up to any diagnostic
-  /// handlers that may be listening.
-  void emitNote(const Twine &message) const;
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const IROperandOwner *ptr) {
-    return ptr->getKind() == IROperandOwner::Kind::Instruction;
-  }
-  static bool classof(const Operation *op) {
-    return op->getOperationKind() == OperationKind::Instruction;
-  }
-
-private:
-  const unsigned numResults, numSuccs;
-  /// The number of operands tail-allocated, mutable to support erasing.
-  unsigned numOperands;
-  BasicBlock *block = nullptr;
-
-  Instruction(Location location, OperationName name, unsigned numResults,
-              unsigned numOperands, unsigned numSuccessors,
-              ArrayRef<NamedAttribute> attributes, MLIRContext *context);
-
-  // Instructions are deleted through the destroy() member because this class
-  // does not have a virtual destructor.
-  ~Instruction();
-
-  /// Erase the operand at 'index'.
-  void eraseOperand(unsigned index);
-
-  friend struct llvm::ilist_traits<Instruction>;
-  friend class BasicBlock;
-
-  // This stuff is used by the TrailingObjects template.
-  friend llvm::TrailingObjects<Instruction, InstResult, InstOperand,
-                               BasicBlockOperand, unsigned>;
-  size_t numTrailingObjects(OverloadToken<InstResult>) const {
-    return numResults;
-  }
-  size_t numTrailingObjects(OverloadToken<InstOperand>) const {
-    return numOperands;
-  }
-  size_t numTrailingObjects(OverloadToken<BasicBlockOperand>) const {
-    return numSuccs;
-  }
-  size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; }
-};
-
-inline raw_ostream &operator<<(raw_ostream &os, const Instruction &inst) {
-  inst.print(os);
-  return os;
-}
-
-} // end namespace mlir
-
-#endif // MLIR_IR_INSTRUCTIONS_H
diff --git a/mlir/include/mlir/IR/MLFunction.h b/mlir/include/mlir/IR/MLFunction.h
deleted file mode 100644 (file)
index e864877..0000000
+++ /dev/null
@@ -1,170 +0,0 @@
-//===- MLFunction.h - MLIR MLFunction Class ---------------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines MLFunction class
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_MLFUNCTION_H_
-#define MLIR_IR_MLFUNCTION_H_
-
-#include "mlir/IR/Function.h"
-#include "mlir/IR/MLValue.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/StmtBlock.h"
-#include "llvm/Support/TrailingObjects.h"
-
-namespace mlir {
-
-template <typename ObjectType, typename ElementType> class ArgumentIterator;
-
-// MLFunction is defined as a sequence of statements that may
-// include nested affine for loops, conditionals and operations.
-class MLFunction final : public Function {
-public:
-  MLFunction(Location location, StringRef name, FunctionType type,
-             ArrayRef<NamedAttribute> attrs = {});
-
-  // TODO(clattner): drop this, it is redundant.
-  static MLFunction *create(Location location, StringRef name,
-                            FunctionType type,
-                            ArrayRef<NamedAttribute> attrs = {}) {
-    return new MLFunction(location, name, type, attrs);
-  }
-
-  StmtBlockList &getBlockList() { return body; }
-  const StmtBlockList &getBlockList() const { return body; }
-
-  StmtBlock *getBody() { return &body.front(); }
-  const StmtBlock *getBody() const { return &body.front(); }
-
-  //===--------------------------------------------------------------------===//
-  // Arguments
-  //===--------------------------------------------------------------------===//
-
-  /// Returns number of arguments.
-  unsigned getNumArguments() const { return getType().getInputs().size(); }
-
-  /// Gets argument.
-  BlockArgument *getArgument(unsigned idx) {
-    return getBlockList().front().getArgument(idx);
-  }
-
-  const BlockArgument *getArgument(unsigned idx) const {
-    return getBlockList().front().getArgument(idx);
-  }
-
-  // Supports non-const operand iteration.
-  using args_iterator = ArgumentIterator<MLFunction, BlockArgument>;
-  args_iterator args_begin();
-  args_iterator args_end();
-  llvm::iterator_range<args_iterator> getArguments();
-
-  // Supports const operand iteration.
-  using const_args_iterator =
-      ArgumentIterator<const MLFunction, const BlockArgument>;
-  const_args_iterator args_begin() const;
-  const_args_iterator args_end() const;
-  llvm::iterator_range<const_args_iterator> getArguments() const;
-
-  //===--------------------------------------------------------------------===//
-  // Other
-  //===--------------------------------------------------------------------===//
-
-  ~MLFunction();
-
-  // Return the 'return' statement of this MLFunction.
-  const OperationStmt *getReturnStmt() const;
-  OperationStmt *getReturnStmt();
-
-  /// Walk the statements in the function in preorder, calling the callback for
-  /// each Operation statement.
-  void walk(std::function<void(OperationStmt *)> callback);
-
-  /// Walk the statements in the function in postorder, calling the callback for
-  /// each Operation statement.
-  void walkPostOrder(std::function<void(OperationStmt *)> callback);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Function *func) {
-    return func->getKind() == Function::Kind::MLFunc;
-  }
-
-private:
-  StmtBlockList body;
-};
-
-//===--------------------------------------------------------------------===//
-// ArgumentIterator
-//===--------------------------------------------------------------------===//
-
-/// This template implements the argument iterator in terms of getArgument(idx).
-template <typename ObjectType, typename ElementType>
-class ArgumentIterator final
-    : public IndexedAccessorIterator<ArgumentIterator<ObjectType, ElementType>,
-                                     ObjectType, ElementType> {
-public:
-  /// Initializes the result iterator to the specified index.
-  ArgumentIterator(ObjectType *object, unsigned index)
-      : IndexedAccessorIterator<ArgumentIterator<ObjectType, ElementType>,
-                                ObjectType, ElementType>(object, index) {}
-
-  /// Support converting to the const variant. This will be a no-op for const
-  /// variant.
-  operator ArgumentIterator<const ObjectType, const ElementType>() const {
-    return ArgumentIterator<const ObjectType, const ElementType>(this->object,
-                                                                 this->index);
-  }
-
-  ElementType *operator*() const {
-    return this->object->getArgument(this->index);
-  }
-};
-
-//===--------------------------------------------------------------------===//
-// MLFunction iterator methods.
-//===--------------------------------------------------------------------===//
-
-inline MLFunction::args_iterator MLFunction::args_begin() {
-  return args_iterator(this, 0);
-}
-
-inline MLFunction::args_iterator MLFunction::args_end() {
-  return args_iterator(this, getNumArguments());
-}
-
-inline llvm::iterator_range<MLFunction::args_iterator>
-MLFunction::getArguments() {
-  return {args_begin(), args_end()};
-}
-
-inline MLFunction::const_args_iterator MLFunction::args_begin() const {
-  return const_args_iterator(this, 0);
-}
-
-inline MLFunction::const_args_iterator MLFunction::args_end() const {
-  return const_args_iterator(this, getNumArguments());
-}
-
-inline llvm::iterator_range<MLFunction::const_args_iterator>
-MLFunction::getArguments() const {
-  return {args_begin(), args_end()};
-}
-
-} // end namespace mlir
-
-#endif  // MLIR_IR_MLFUNCTION_H_
index 27921520e385c29b658c523a9122480cf5960af5..a1b5412affa8ef8cd68cde5c9752168081fbc4da 100644 (file)
@@ -27,7 +27,7 @@
 namespace mlir {
 class ForStmt;
 class MLValue;
-class MLFunction;
+using MLFunction = Function;
 class Statement;
 class StmtBlock;
 
@@ -58,10 +58,6 @@ public:
     case SSAValueKind::StmtResult:
     case SSAValueKind::ForStmt:
       return true;
-
-    case SSAValueKind::BBArgument:
-    case SSAValueKind::InstResult:
-      return false;
     }
   }
 
@@ -127,6 +123,11 @@ private:
   OperationStmt *const owner;
 };
 
+// TODO(clattner) clean all this up.
+using CFGValue = MLValue;
+using BBArgument = BlockArgument;
+using InstResult = StmtResult;
+
 } // namespace mlir
 
 #endif
index e21576460141ba71112f4d93a5c4765fb929320e..3c68bb0f30f7c0eca94fdb6f49f2a244b51d999f 100644 (file)
@@ -33,7 +33,6 @@
 #include <type_traits>
 
 namespace mlir {
-class BasicBlock;
 class Builder;
 
 namespace OpTrait {
index 794eb190cbe25ef2c0cd4f230ee0d2fa35c71175..93a2c3061d9777d690bb83f200eba8580fb27082 100644 (file)
@@ -24,7 +24,6 @@
 
 namespace mlir {
 class AttributeListStorage;
-class BasicBlock;
 template <typename OpType> class ConstOpPointer;
 template <typename OpType> class OpPointer;
 template <typename ObjectType, typename ElementType> class OperandIterator;
@@ -32,8 +31,9 @@ template <typename ObjectType, typename ElementType> class ResultIterator;
 template <typename ObjectType, typename ElementType> class ResultTypeIterator;
 class Function;
 class IROperandOwner;
-class Instruction;
 class Statement;
+class OperationStmt;
+using Instruction = Statement;
 
 /// Operations represent all of the arithmetic and other basic computation in
 /// MLIR.  This class is the common implementation details behind Instruction
index 308a84aa3b161db82ea00c92651239bd651c3eb0..eaaf927b642072c023003d5491a1a42771fa31e9 100644 (file)
@@ -31,7 +31,6 @@
 #include <memory>
 
 namespace mlir {
-class BasicBlock;
 class Dialect;
 class Operation;
 class OperationState;
@@ -43,6 +42,7 @@ class RewritePattern;
 class StmtBlock;
 class SSAValue;
 class Type;
+using BasicBlock = StmtBlock;
 
 /// This is a vector that owns the patterns inside of it.
 using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
@@ -208,10 +208,7 @@ struct OperationState {
   SmallVector<Type, 4> types;
   SmallVector<NamedAttribute, 4> attributes;
   /// Successors of this operation and their respective operands.
-  SmallVector<BasicBlock *, 1> successors;
-
-  // TODO: rename to successors when CFG and ML Functions are merged.
-  SmallVector<StmtBlock *, 1> successorsS;
+  SmallVector<StmtBlock *, 1> successors;
 
 public:
   OperationState(MLIRContext *context, Location location, StringRef name)
@@ -220,25 +217,15 @@ public:
   OperationState(MLIRContext *context, Location location, OperationName name)
       : context(context), location(location), name(name) {}
 
-  OperationState(MLIRContext *context, Location location, StringRef name,
-                 ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
-                 ArrayRef<NamedAttribute> attributes = {},
-                 ArrayRef<BasicBlock *> successors = {})
-      : context(context), location(location), name(name, context),
-        operands(operands.begin(), operands.end()),
-        types(types.begin(), types.end()),
-        attributes(attributes.begin(), attributes.end()),
-        successors(successors.begin(), successors.end()) {}
-
   OperationState(MLIRContext *context, Location location, StringRef name,
                  ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
                  ArrayRef<NamedAttribute> attributes,
-                 ArrayRef<StmtBlock *> successors)
+                 ArrayRef<StmtBlock *> successors = {})
       : context(context), location(location), name(name, context),
         operands(operands.begin(), operands.end()),
         types(types.begin(), types.end()),
         attributes(attributes.begin(), attributes.end()),
-        successorsS(successors.begin(), successors.end()) {}
+        successors(successors.begin(), successors.end()) {}
 
   void addOperands(ArrayRef<SSAValue *> newOperands) {
     assert(successors.empty() &&
@@ -260,7 +247,7 @@ public:
     attributes.push_back({name, attr});
   }
 
-  void addSuccessor(BasicBlock *successor, ArrayRef<SSAValue *> succOperands) {
+  void addSuccessor(StmtBlock *successor, ArrayRef<SSAValue *> succOperands) {
     successors.push_back(successor);
     // Insert a sentinal operand to mark a barrier between successor operands.
     operands.push_back(nullptr);
index 29df2181264b47e16e4e2b67a9204c0242459aea..5791cbfd17a5cf0a664794de01e12f904c967d80 100644 (file)
 
 namespace mlir {
 class Function;
-class Instruction;
 class OperationStmt;
 class Operation;
+class Statement;
+using Instruction = Statement;
+using OperationInst = OperationStmt;
 
 /// This enumerates all of the SSA value kinds in the MLIR system.
 enum class SSAValueKind {
-  BBArgument,     // basic block argument
-  InstResult,     // instruction result
   BlockArgument,  // Block argument
   StmtResult,     // statement result
   ForStmt,        // for statement induction variable
@@ -69,8 +69,8 @@ public:
 
   /// If this value is the result of an Instruction, return the instruction
   /// that defines it.
-  Instruction *getDefiningInst();
-  const Instruction *getDefiningInst() const {
+  OperationInst *getDefiningInst();
+  const OperationInst *getDefiningInst() const {
     return const_cast<SSAValue *>(this)->getDefiningInst();
   }
 
index 188002a646ca2414344e738286634834497942a3..c7eddaf8d3cadc3cee294c9bbcec35fd77901440 100644 (file)
@@ -30,7 +30,7 @@
 
 namespace mlir {
 class Location;
-class MLFunction;
+using MLFunction = Function;
 class StmtBlock;
 class ForStmt;
 class MLIRContext;
@@ -93,6 +93,7 @@ public:
   /// sub-statements to the corresponding statement that is copied, and adds
   /// those mappings to the map.
   Statement *clone(OperandMapTy &operandMap, MLIRContext *context) const;
+  Statement *clone(MLIRContext *context) const;
 
   /// Returns the statement block that contains this statement.
   StmtBlock *getBlock() const { return block; }
@@ -109,6 +110,11 @@ public:
   /// Destroys this statement and its subclass data.
   void destroy();
 
+  /// This drops all operand uses from this instruction, which is an essential
+  /// step in breaking cyclic dependences between references when they are to
+  /// be deleted.
+  void dropAllReferences();
+
   /// Unlink this statement from its current block and insert it right before
   /// `existingStmt` which may be in the same or another block in the same
   /// function.
@@ -118,6 +124,9 @@ public:
   /// it right before `iterator` in the specified basic block.
   void moveBefore(StmtBlock *block, llvm::iplist<Statement>::iterator iterator);
 
+  // Returns whether the Statement is a terminator.
+  bool isTerminator() const;
+
   void print(raw_ostream &os) const;
   void dump() const;
 
index d653f59a26196739815b76eb8dadfaed0e18188a..b1c03e948db1e19f6c5c54d33b350095993e5b45 100644 (file)
@@ -32,6 +32,7 @@ class AffineBound;
 class IntegerSet;
 class AffineCondition;
 class OperationStmt;
+using OperationInst = OperationStmt;
 
 /// Operation statements represent operations inside ML functions.
 class OperationStmt final
@@ -49,6 +50,7 @@ public:
   /// Return the context this operation is associated with.
   MLIRContext *getContext() const;
 
+  using Operation::isTerminator;
   using Statement::dump;
   using Statement::emitError;
   using Statement::emitNote;
@@ -220,6 +222,15 @@ public:
   }
   void setSuccessor(BasicBlock *block, unsigned index);
 
+  /// Erase a specific operand from the operand list of the successor at
+  /// 'index'.
+  void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
+    assert(succIndex < getNumSuccessors());
+    assert(opIndex < getNumSuccessorOperands(succIndex));
+    eraseOperand(getSuccessorOperandIndex(succIndex) + opIndex);
+    --getTrailingObjects<unsigned>()[succIndex];
+  }
+
   /// Get the index of the first operand of the successor at the provided
   /// index.
   unsigned getSuccessorOperandIndex(unsigned index) const {
@@ -252,13 +263,17 @@ public:
   }
 
 private:
-  const unsigned numOperands, numResults, numSuccs;
+  unsigned numOperands;
+  const unsigned numResults, numSuccs;
 
   OperationStmt(Location location, OperationName name, unsigned numOperands,
                 unsigned numResults, unsigned numSuccessors,
                 ArrayRef<NamedAttribute> attributes, MLIRContext *context);
   ~OperationStmt();
 
+  /// Erase the operand at 'index'.
+  void eraseOperand(unsigned index);
+
   // This stuff is used by the TrailingObjects template.
   friend llvm::TrailingObjects<OperationStmt, StmtResult, StmtBlockOperand,
                                unsigned, StmtOperand>;
index fe4d5f417cce08bad4ea36d73d1a4d5a30253556..23b682043ac680c65d74b1592143280d4c3c9837 100644 (file)
 #include "mlir/IR/Statement.h"
 
 namespace mlir {
-class MLFunction;
 class IfStmt;
 class MLValue;
 class StmtBlockList;
+using CFGFunction = Function;
+using MLFunction = Function;
 
 // TODO(clattner): drop the Stmt prefixes on these once BasicBlock's versions of
 // these go away.
@@ -218,14 +219,40 @@ public:
   succ_iterator succ_end();
   llvm::iterator_range<succ_iterator> getSuccessors();
 
+  //===--------------------------------------------------------------------===//
+  // Other
+  //===--------------------------------------------------------------------===//
+
+  /// Unlink this BasicBlock from its CFGFunction and delete it.
+  void eraseFromFunction();
+
+  /// Split the basic block into two basic blocks before the specified
+  /// instruction or iterator.
+  ///
+  /// Note that all instructions BEFORE the specified iterator stay as part of
+  /// the original basic block, an unconditional branch is added to the original
+  /// block (going to the new block), and the rest of the instructions in the
+  /// original block are moved to the new BB, including the old terminator.  The
+  /// newly formed BasicBlock is returned.
+  ///
+  /// This function invalidates the specified iterator.
+  BasicBlock *splitBasicBlock(iterator splitBefore);
+  BasicBlock *splitBasicBlock(Instruction *splitBeforeInst) {
+    return splitBasicBlock(iterator(splitBeforeInst));
+  }
+
   /// getSublistAccess() - Returns pointer to member of statement list
   static StmtListType StmtBlock::*getSublistAccess(Statement *) {
     return &StmtBlock::statements;
   }
 
-  /// These have unconventional names to avoid derive class ambiguities.
-  void printBlock(raw_ostream &os) const;
-  void dumpBlock() const;
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  /// Print out the name of the basic block without printing its body.
+  /// NOTE: The printType argument is ignored.  We keep it for compatibility
+  /// with LLVM dominator machinery that expects it to exist.
+  void printAsOperand(raw_ostream &os, bool printType = true);
 
 private:
   /// This is the parent function/IfStmt/ForStmt that owns this block.
@@ -273,7 +300,7 @@ namespace mlir {
 /// is part of - an MLFunction or IfStmt or ForStmt.
 class StmtBlockList {
 public:
-  explicit StmtBlockList(MLFunction *container);
+  explicit StmtBlockList(Function *container);
   explicit StmtBlockList(Statement *container);
 
   using BlockListType = llvm::iplist<StmtBlock>;
@@ -314,25 +341,32 @@ public:
     return &StmtBlockList::blocks;
   }
 
-  /// A StmtBlockList is part of a MLFunction or and IfStmt/ForStmt.  If it is
+  /// A StmtBlockList is part of a Function or and IfStmt/ForStmt.  If it is
   /// part of an IfStmt/ForStmt, then return it, otherwise return null.
   Statement *getContainingStmt();
   const Statement *getContainingStmt() const {
     return const_cast<StmtBlockList *>(this)->getContainingStmt();
   }
 
-  /// A StmtBlockList is part of a MLFunction or and IfStmt/ForStmt.  If it is
-  /// part of an MLFunction, then return it, otherwise return null.
-  MLFunction *getContainingFunction();
-  const MLFunction *getContainingFunction() const {
+  /// A StmtBlockList is part of a Function or and IfStmt/ForStmt.  If it is
+  /// part of an Function, then return it, otherwise return null.
+  Function *getContainingFunction();
+  const Function *getContainingFunction() const {
     return const_cast<StmtBlockList *>(this)->getContainingFunction();
   }
 
+  // TODO(clattner): This is only to help ML -> CFG migration, remove in the
+  // near future.  This makes StmtBlockList work more like BasicBlock did.
+  CFGFunction *getFunction();
+  const CFGFunction *getFunction() const {
+    return const_cast<StmtBlockList *>(this)->getFunction();
+  }
+
 private:
   BlockListType blocks;
 
   /// This is the object we are part of.
-  llvm::PointerUnion<MLFunction *, Statement *> container;
+  llvm::PointerUnion<Function *, Statement *> container;
 };
 
 //===----------------------------------------------------------------------===//
index 8dcd5863096a90094ecad98cd619f8e23454febe..a0f787fea4d0d30099f12e0d5002f621ff265ab7 100644 (file)
@@ -65,7 +65,7 @@
 #ifndef MLIR_IR_STMTVISITOR_H
 #define MLIR_IR_STMTVISITOR_H
 
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/Statements.h"
 
 namespace mlir {
index 6aaecb2f34a3be2feda009d7568f73ff5c1f69b5..75f682d3c1014c5e132b9c6e311c5bb3c8ba3550 100644 (file)
@@ -25,8 +25,8 @@
 
 namespace mlir {
 class Function;
-class CFGFunction;
-class MLFunction;
+using CFGFunction = Function;
+using MLFunction = Function;
 class Module;
 
 // Values that can be used by to signal success/failure. This can be implicitly
index bce4c5e0afdd6b23db9ebadb11f4d52c91bea253..0c3c17d22134e1cbce97bfff2cafbc733cfc561d 100644 (file)
@@ -22,7 +22,7 @@
 #ifndef MLIR_TRANSFORMS_CFGFUNCTIONVIEWGRAPH_H_
 #define MLIR_TRANSFORMS_CFGFUNCTIONVIEWGRAPH_H_
 
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/Pass.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/GraphWriter.h"
index cc28cd3bad655988891a45b0bba539c179dc24f8..82213bcdb99a14cb6a5b660951763346f8d7f490 100644 (file)
@@ -30,7 +30,8 @@ namespace mlir {
 
 class AffineMap;
 class ForStmt;
-class MLFunction;
+class Function;
+using MLFunction = Function;
 class MLFuncBuilder;
 
 // Values that can be used to signal success/failure. This can be implicitly
index 7fe4b8a0a0652dcf695c0b5608812e762ccb5bdf..119f2add54ad4937d18bc4fab08dd9365f927684 100644 (file)
@@ -32,7 +32,6 @@
 
 namespace mlir {
 
-class CFGFunction;
 class ForStmt;
 class FuncBuilder;
 class Location;
@@ -40,6 +39,8 @@ class MLValue;
 class Module;
 class OperationStmt;
 class SSAValue;
+class Function;
+using CFGFunction = Function;
 
 /// Replace all uses of oldMemRef with newMemRef while optionally remapping the
 /// old memref's indices using the supplied affine map and adding any additional
index 499d6351b33dd01ae43331ee1583eddc429a31b3..b3faaf3eae08c250d232cad774980a46b5f72d70 100644 (file)
@@ -21,6 +21,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Dominance.h"
+#include "mlir/IR/Statements.h"
 #include "llvm/Support/GenericDomTreeConstruction.h"
 using namespace mlir;
 
@@ -31,7 +32,7 @@ template class llvm::DomTreeNodeBase<BasicBlock>;
 /// Compute the immediate-dominators map.
 DominanceInfo::DominanceInfo(CFGFunction *function) : DominatorTreeBase() {
   // Build the dominator tree for the function.
-  recalculate(*function);
+  recalculate(function->getBlockList());
 }
 
 /// Return true if instruction A properly dominates instruction B.
index 04fe8114c3b841703d622ab5c3cb027fb1010ad8..d9a0edd6d833a4fabbd4b6e6e22f6d10955ca696 100644 (file)
@@ -15,8 +15,7 @@
 // limitations under the License.
 // =============================================================================
 
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/StmtVisitor.h"
@@ -65,7 +64,8 @@ PassResult PrintOpStatsPass::runOnModule(Module *m) {
 PassResult PrintOpStatsPass::runOnCFGFunction(CFGFunction *function) {
   for (const auto &bb : *function)
     for (const auto &inst : bb)
-      ++opCount[inst.getName().getStringRef()];
+      if (auto *op = dyn_cast<OperationInst>(&inst))
+        ++opCount[op->getName().getStringRef()];
   return success();
 }
 
index ca59a5f899838da4066c74b5862200883791e61f..0aefc10404f06b760e4876108e2e8e09d436d303 100644 (file)
@@ -20,8 +20,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Pass.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Support/PassNameParser.h"
 #include "llvm/ADT/DenseMap.h"
@@ -47,10 +46,10 @@ PassResult FunctionPass::runOnModule(Module *m) {
 }
 
 PassResult FunctionPass::runOnFunction(Function *fn) {
-  if (auto *mlFunc = dyn_cast<MLFunction>(fn))
-    return runOnMLFunction(mlFunc);
-  if (auto *cfgFunc = dyn_cast<CFGFunction>(fn))
-    return runOnCFGFunction(cfgFunc);
+  if (fn->isML())
+    return runOnMLFunction(fn);
+  if (fn->isCFG())
+    return runOnCFGFunction(fn);
 
   return success();
 }
index 07324ba7d528038fa74cb444a4801b5af2d24d39..a04cee7512dec1527b201c7f009131f056c22cfb 100644 (file)
@@ -35,8 +35,7 @@
 
 #include "mlir/Analysis/Dominance.h"
 #include "mlir/IR/Attributes.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/StmtVisitor.h"
@@ -63,7 +62,8 @@ public:
   bool failure(const Twine &message, const BasicBlock &bb) {
     // Take the location information for the first instruction in the block.
     if (!bb.empty())
-      return failure(message, bb.front());
+      if (auto *op = dyn_cast<OperationStmt>(&bb.front()))
+        return failure(message, *op);
 
     // Worst case, fall back to using the function's location.
     return failure(message, fn);
@@ -224,7 +224,11 @@ bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) {
   }
 
   for (auto &inst : block) {
-    if (verifyOperation(inst) || verifyInstOperands(inst))
+    if (auto *opInst = dyn_cast<OperationInst>(&inst))
+      if (verifyOperation(*opInst))
+        return true;
+
+    if (verifyInstOperands(inst))
       return true;
   }
   return false;
index fa04b1a3a85fb57b1617dca2b7cd13526797ae2c..4778564cb4debc75823a35a0e9ed06aeb18bb901 100644 (file)
@@ -24,9 +24,8 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Statements.h"
@@ -116,7 +115,7 @@ private:
 
   // Visit functions.
   void visitFunction(const Function *fn);
-  void visitExtFunction(const ExtFunction *fn);
+  void visitExtFunction(const Function *fn);
   void visitCFGFunction(const CFGFunction *fn);
   void visitMLFunction(const MLFunction *fn);
   void visitStatement(const Statement *stmt);
@@ -175,15 +174,19 @@ void ModuleState::visitOperation(const Operation *op) {
     visitAttribute(elt.second);
 }
 
-void ModuleState::visitExtFunction(const ExtFunction *fn) {
+void ModuleState::visitExtFunction(const Function *fn) {
   visitType(fn->getType());
 }
 
 void ModuleState::visitCFGFunction(const CFGFunction *fn) {
   visitType(fn->getType());
   for (auto &block : *fn) {
-    for (auto &op : block.getOperations()) {
-      visitOperation(&op);
+    for (auto &op : block.getStatements()) {
+      if (auto *opInst = dyn_cast<OperationInst>(&op))
+        visitOperation(opInst);
+      else {
+        llvm_unreachable("IfStmt/ForStmt in a CFGFunction isn't supported");
+      }
     }
   }
 }
@@ -238,11 +241,11 @@ void ModuleState::visitMLFunction(const MLFunction *fn) {
 void ModuleState::visitFunction(const Function *fn) {
   switch (fn->getKind()) {
   case Function::Kind::ExtFunc:
-    return visitExtFunction(cast<ExtFunction>(fn));
+    return visitExtFunction(fn);
   case Function::Kind::CFGFunc:
-    return visitCFGFunction(cast<CFGFunction>(fn));
+    return visitCFGFunction(fn);
   case Function::Kind::MLFunc:
-    return visitMLFunction(cast<MLFunction>(fn));
+    return visitMLFunction(fn);
   }
 }
 
@@ -274,9 +277,9 @@ public:
   void printAttribute(Attribute attr);
   void printType(Type type);
   void print(const Function *fn);
-  void print(const ExtFunction *fn);
-  void print(const CFGFunction *fn);
-  void print(const MLFunction *fn);
+  void printExt(const Function *fn);
+  void printCFG(const Function *fn);
+  void printML(const Function *fn);
 
   void printAffineMap(AffineMap map);
   void printAffineExpr(AffineExpr expr);
@@ -314,11 +317,11 @@ protected:
 void ModulePrinter::print(const Function *fn) {
   switch (fn->getKind()) {
   case Function::Kind::ExtFunc:
-    return print(cast<ExtFunction>(fn));
+    return printExt(fn);
   case Function::Kind::CFGFunc:
-    return print(cast<CFGFunction>(fn));
+    return printCFG(fn);
   case Function::Kind::MLFunc:
-    return print(cast<MLFunction>(fn));
+    return printML(fn);
   }
 }
 
@@ -927,7 +930,7 @@ void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
   os << '}';
 }
 
-void ModulePrinter::print(const ExtFunction *fn) {
+void ModulePrinter::printExt(const Function *fn) {
   os << "extfunc ";
   printFunctionSignature(fn);
   printFunctionAttributes(fn);
@@ -1001,17 +1004,6 @@ protected:
 
     if (specialNameBuffer.empty()) {
       switch (value->getKind()) {
-      case SSAValueKind::BBArgument:
-        // If this is an argument to the function, give it an 'arg' name.
-        if (auto *bb = cast<BBArgument>(value)->getOwner())
-          if (auto *fn = bb->getFunction())
-            if (&fn->front() == bb) {
-              specialName << "arg" << nextArgumentID++;
-              break;
-            }
-        // Otherwise number it normally.
-        valueIDs[value] = nextValueID++;
-        return;
       case SSAValueKind::BlockArgument:
         // If this is an argument to the function, give it an 'arg' name.
         if (auto *block = cast<BlockArgument>(value)->getOwner())
@@ -1023,7 +1015,6 @@ protected:
         // Otherwise number it normally.
         valueIDs[value] = nextValueID++;
         return;
-      case SSAValueKind::InstResult:
       case SSAValueKind::StmtResult:
         // This is an uninteresting result, give it a boring number and be
         // done with it.
@@ -1222,8 +1213,9 @@ void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
   for (auto &op : *block) {
     // We number instruction that have results, and we only number the first
     // result.
-    if (op.getNumResults() != 0)
-      numberValueID(op.getResult(0));
+    if (auto *opInst = dyn_cast<OperationInst>(&op))
+      if (opInst->getNumResults() != 0)
+        numberValueID(opInst->getResult(0));
   }
 
   // Terminators do not define values.
@@ -1278,7 +1270,7 @@ void CFGFunctionPrinter::print(const BasicBlock *block) {
   }
   os << '\n';
 
-  for (auto &inst : block->getOperations()) {
+  for (auto &inst : block->getStatements()) {
     os << "  ";
     print(&inst);
     os << '\n';
@@ -1290,7 +1282,9 @@ void CFGFunctionPrinter::print(const Instruction *inst) {
     os << "<<null instruction>>\n";
     return;
   }
-  printOperation(inst);
+  auto *opInst = dyn_cast<OperationInst>(inst);
+  assert(opInst && "IfStmt/ForStmt aren't supported in CFG functions yet");
+  printOperation(opInst);
 }
 
 // Print the operands from "container" to "os", followed by a colon and their
@@ -1317,7 +1311,7 @@ void CFGFunctionPrinter::printSuccessorAndUseList(const Operation *term,
   printBranchOperands(term->getSuccessorOperands(index));
 }
 
-void ModulePrinter::print(const CFGFunction *fn) {
+void ModulePrinter::printCFG(const Function *fn) {
   CFGFunctionPrinter(fn, *this).print();
 }
 
@@ -1530,7 +1524,7 @@ void MLFunctionPrinter::print(const IfStmt *stmt) {
   }
 }
 
-void ModulePrinter::print(const MLFunction *fn) {
+void ModulePrinter::printML(const Function *fn) {
   MLFunctionPrinter(fn, *this).print();
 }
 
@@ -1584,13 +1578,10 @@ void IntegerSet::print(raw_ostream &os) const {
 
 void SSAValue::print(raw_ostream &os) const {
   switch (getKind()) {
-  case SSAValueKind::BBArgument:
   case SSAValueKind::BlockArgument:
     // TODO: Improve this.
     os << "<block argument>\n";
     return;
-  case SSAValueKind::InstResult:
-    return getDefiningInst()->print(os);
   case SSAValueKind::StmtResult:
     return getDefiningStmt()->print(os);
   case SSAValueKind::ForStmt:
@@ -1601,13 +1592,20 @@ void SSAValue::print(raw_ostream &os) const {
 void SSAValue::dump() const { print(llvm::errs()); }
 
 void Instruction::print(raw_ostream &os) const {
-  if (!getFunction()) {
+  auto *function = getFunction();
+  if (!function) {
     os << "<<UNLINKED INSTRUCTION>>\n";
     return;
   }
-  ModuleState state(getFunction()->getContext());
-  ModulePrinter modulePrinter(os, state);
-  CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
+  if (function->isCFG()) {
+    ModuleState state(function->getContext());
+    ModulePrinter modulePrinter(os, state);
+    CFGFunctionPrinter(function, modulePrinter).print(this);
+  } else {
+    ModuleState state(function->getContext());
+    ModulePrinter modulePrinter(os, state);
+    MLFunctionPrinter(function, modulePrinter).print(this);
+  }
 }
 
 void Instruction::dump() const {
@@ -1616,19 +1614,27 @@ void Instruction::dump() const {
 }
 
 void BasicBlock::print(raw_ostream &os) const {
-  if (!getFunction()) {
+  auto *function = getFunction();
+  if (!function) {
     os << "<<UNLINKED BLOCK>>\n";
     return;
   }
-  ModuleState state(getFunction()->getContext());
-  ModulePrinter modulePrinter(os, state);
-  CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
+
+  if (function->isCFG()) {
+    ModuleState state(function->getContext());
+    ModulePrinter modulePrinter(os, state);
+    CFGFunctionPrinter(function, modulePrinter).print(this);
+  } else {
+    ModuleState state(function->getContext());
+    ModulePrinter modulePrinter(os, state);
+    MLFunctionPrinter(function, modulePrinter).print(this);
+  }
 }
 
 void BasicBlock::dump() const { print(llvm::errs()); }
 
 /// Print out the name of the basic block without printing its body.
-void BasicBlock::printAsOperand(raw_ostream &os, bool printType) {
+void StmtBlock::printAsOperand(raw_ostream &os, bool printType) {
   if (!getFunction()) {
     os << "<<UNLINKED BLOCK>>\n";
     return;
@@ -1638,29 +1644,6 @@ void BasicBlock::printAsOperand(raw_ostream &os, bool printType) {
   CFGFunctionPrinter(getFunction(), modulePrinter).printBBName(this);
 }
 
-void Statement::print(raw_ostream &os) const {
-  MLFunction *function = getFunction();
-  if (!function) {
-    os << "<<UNLINKED STATEMENT>>\n";
-    return;
-  }
-
-  ModuleState state(function->getContext());
-  ModulePrinter modulePrinter(os, state);
-  MLFunctionPrinter(function, modulePrinter).print(this);
-}
-
-void Statement::dump() const { print(llvm::errs()); }
-
-void StmtBlock::printBlock(raw_ostream &os) const {
-  const MLFunction *function = getFunction();
-  ModuleState state(function->getContext());
-  ModulePrinter modulePrinter(os, state);
-  MLFunctionPrinter(function, modulePrinter).print(this);
-}
-
-void StmtBlock::dumpBlock() const { printBlock(llvm::errs()); }
-
 void Function::print(raw_ostream &os) const {
   ModuleState state(getContext());
   ModulePrinter(os, state).print(this);
diff --git a/mlir/lib/IR/BasicBlock.cpp b/mlir/lib/IR/BasicBlock.cpp
deleted file mode 100644 (file)
index 8ba457a..0000000
+++ /dev/null
@@ -1,184 +0,0 @@
-//===- BasicBlock.cpp - MLIR BasicBlock Class -----------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/BasicBlock.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/CFGFunction.h"
-using namespace mlir;
-
-BasicBlock::BasicBlock() {}
-
-BasicBlock::~BasicBlock() {
-  for (BBArgument *arg : arguments)
-    delete arg;
-  arguments.clear();
-}
-
-//===----------------------------------------------------------------------===//
-// Argument list management.
-//===----------------------------------------------------------------------===//
-
-BBArgument *BasicBlock::addArgument(Type type) {
-  auto *arg = new BBArgument(type, this);
-  arguments.push_back(arg);
-  return arg;
-}
-
-/// Add one argument to the argument list for each type specified in the list.
-auto BasicBlock::addArguments(ArrayRef<Type> types)
-    -> llvm::iterator_range<args_iterator> {
-  arguments.reserve(arguments.size() + types.size());
-  auto initialSize = arguments.size();
-  for (auto type : types) {
-    addArgument(type);
-  }
-  return {arguments.data() + initialSize, arguments.data() + arguments.size()};
-}
-
-void BasicBlock::eraseArgument(unsigned index) {
-  assert(index < arguments.size());
-
-  // Delete the argument.
-  delete arguments[index];
-  arguments.erase(arguments.begin() + index);
-
-  // Erase this argument from each of the predecessor's terminator.
-  for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE;
-       ++predIt) {
-    auto *predTerminator = (*predIt)->getTerminator();
-    predTerminator->eraseSuccessorOperand(predIt.getSuccessorIndex(), index);
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// Terminator management
-//===----------------------------------------------------------------------===//
-
-Instruction *BasicBlock::getTerminator() const {
-  if (empty())
-    return nullptr;
-
-  // Check if the last instruction is a terminator.
-  auto &backInst = operations.back();
-  return backInst.isTerminator() ? const_cast<Instruction *>(&backInst)
-                                 : nullptr;
-}
-
-/// Return true if this block has no predecessors.
-bool BasicBlock::hasNoPredecessors() const {
-  return pred_begin() == pred_end();
-}
-
-/// If this basic block has exactly one predecessor, return it.  Otherwise,
-/// return null.
-///
-/// Note that multiple edges from a single block (e.g. if you have a cond
-/// branch with the same block as the true/false destinations) is not
-/// considered to be a single predecessor.
-BasicBlock *BasicBlock::getSinglePredecessor() {
-  auto it = pred_begin();
-  if (it == pred_end())
-    return nullptr;
-  auto *firstPred = *it;
-  ++it;
-  return it == pred_end() ? firstPred : nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// ilist_traits for BasicBlock
-//===----------------------------------------------------------------------===//
-
-mlir::CFGFunction *
-llvm::ilist_traits<::mlir::BasicBlock>::getContainingFunction() {
-  size_t Offset(
-    size_t(&((CFGFunction *)nullptr->*CFGFunction::getSublistAccess(nullptr))));
-  iplist<BasicBlock> *Anchor(static_cast<iplist<BasicBlock> *>(this));
-  return reinterpret_cast<CFGFunction *>(reinterpret_cast<char *>(Anchor) -
-                                           Offset);
-}
-
-/// This is a trait method invoked when a basic block is added to a function.
-/// We keep the function pointer up to date.
-void llvm::ilist_traits<::mlir::BasicBlock>::
-addNodeToList(BasicBlock *block) {
-  assert(!block->function && "already in a function!");
-  block->function = getContainingFunction();
-}
-
-/// This is a trait method invoked when an instruction is removed from a
-/// function.  We keep the function pointer up to date.
-void llvm::ilist_traits<::mlir::BasicBlock>::
-removeNodeFromList(BasicBlock *block) {
-  assert(block->function && "not already in a function!");
-  block->function = nullptr;
-}
-
-/// This is a trait method invoked when an instruction is moved from one block
-/// to another.  We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::BasicBlock>::
-transferNodesFromList(ilist_traits<BasicBlock> &otherList,
-                      block_iterator first, block_iterator last) {
-  // If we are transferring instructions within the same function, the parent
-  // pointer doesn't need to be updated.
-  CFGFunction *curParent = getContainingFunction();
-  if (curParent == otherList.getContainingFunction())
-    return;
-
-  // Update the 'function' member of each BasicBlock.
-  for (; first != last; ++first)
-    first->function = curParent;
-}
-
-//===----------------------------------------------------------------------===//
-// Manipulators
-//===----------------------------------------------------------------------===//
-
-/// Unlink this BasicBlock from its CFGFunction and delete it.
-void BasicBlock::eraseFromFunction() {
-  assert(getFunction() && "BasicBlock has no parent");
-  getFunction()->getBlocks().erase(this);
-}
-
-/// Split the basic block into two basic blocks before the specified
-/// instruction or iterator.
-///
-/// Note that all instructions BEFORE the specified iterator stay as part of
-/// the original basic block, an unconditional branch is added to the original
-/// block (going to the new block), and the rest of the instructions in the
-/// original block are moved to the new BB, including the old terminator.  The
-/// newly formed BasicBlock is returned.
-///
-/// This function invalidates the specified iterator.
-BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) {
-  // Start by creating a new basic block, and insert it immediate after this
-  // one in the containing function.
-  auto newBB = new BasicBlock();
-  getFunction()->getBlocks().insert(++CFGFunction::iterator(this), newBB);
-  auto branchLoc =
-      splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc();
-
-  // Move all of the operations from the split point to the end of the function
-  // into the new block.
-  newBB->getOperations().splice(newBB->end(), getOperations(), splitBefore,
-                                end());
-
-  // Create an unconditional branch to the new block, and move our terminator
-  // to the new block.
-  CFGFuncBuilder(this).create<BranchOp>(branchLoc, newBB);
-  return newBB;
-}
index 67b09c5aa6e317973a2e686cf20d4324b410daf4..0732448fb87847860f7a5f0bea396fd78d448cbb 100644 (file)
@@ -290,18 +290,18 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
 }
 
 /// Create an operation given the fields represented as an OperationState.
-Instruction *CFGFuncBuilder::createOperation(const OperationState &state) {
+OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
   SmallVector<CFGValue *, 8> operands;
   operands.reserve(state.operands.size());
   // Allow null operands as they act as sentinal barriers between successor
   // operand lists.
   for (auto elt : state.operands)
-    operands.push_back(elt ? cast<CFGValue>(elt) : nullptr);
+    operands.push_back(cast_or_null<CFGValue>(elt));
 
   auto *op =
-      Instruction::create(state.location, state.name, operands, state.types,
-                          state.attributes, state.successors, context);
-  block->getOperations().insert(insertPoint, op);
+      OperationInst::create(state.location, state.name, operands, state.types,
+                            state.attributes, state.successors, context);
+  block->getStatements().insert(insertPoint, op);
   return op;
 }
 
@@ -318,7 +318,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
 
   auto *op =
       OperationStmt::create(state.location, state.name, operands, state.types,
-                            state.attributes, state.successorsS, context);
+                            state.attributes, state.successors, context);
   block->getStatements().insert(insertPoint, op);
   return op;
 }
index 04d032dc2ebcdb938eab053cdc9d033279ee376b..cdf98ca4bee958872f2af5efd6df3560ab3beedc 100644 (file)
@@ -190,7 +190,7 @@ void BranchOp::print(OpAsmPrinter *p) const {
 
 bool BranchOp::verify() const {
   // ML functions do not have branching terminators.
-  if (!isa<Instruction>(getOperation()))
+  if (getOperation()->getOperationFunction()->isML())
     return (emitOpError("cannot occur in a ML function"), true);
   return false;
 }
@@ -261,7 +261,7 @@ void CondBranchOp::print(OpAsmPrinter *p) const {
 
 bool CondBranchOp::verify() const {
   // ML functions do not have branching terminators.
-  if (!isa<Instruction>(getOperation()))
+  if (getOperation()->getOperationFunction()->isML())
     return (emitOpError("cannot occur in a ML function"), true);
   if (!getCondition()->getType().isInteger(1))
     return emitOpError("expected condition type was boolean (i1)");
@@ -472,11 +472,7 @@ void ReturnOp::print(OpAsmPrinter *p) const {
 }
 
 bool ReturnOp::verify() const {
-  const Function *function;
-  if (auto *stmt = dyn_cast<OperationStmt>(getOperation()))
-    function = stmt->getFunction();
-  else
-    function = cast<Instruction>(getOperation())->getFunction();
+  auto *function = cast<OperationStmt>(getOperation())->getFunction();
 
   // The operand number and types must match the function signature.
   const auto &results = function->getType().getResults();
index c21bd93aeadc7045a6139852e4182fd3572d14d2..62f1dca067ddaba526ccbb9171f63df95650fd38 100644 (file)
 // limitations under the License.
 // =============================================================================
 
+#include "mlir/IR/Function.h"
 #include "AttributeListStorage.h"
 #include "mlir/IR/Attributes.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StmtVisitor.h"
@@ -30,11 +29,29 @@ using namespace mlir;
 Function::Function(Kind kind, Location location, StringRef name,
                    FunctionType type, ArrayRef<NamedAttribute> attrs)
     : nameAndKind(Identifier::get(name, type.getContext()), kind),
-      location(location), type(type) {
+      location(location), type(type), blocks(this) {
   this->attrs = AttributeListStorage::get(attrs, getContext());
+
+  // Creating of an MLFunction automatically populates the entry block and
+  // arguments.
+  // TODO(clattner): Unify this behavior.
+  if (kind == Kind::MLFunc) {
+    // The body of an MLFunction always has one block.
+    auto *entry = new StmtBlock();
+    blocks.push_back(entry);
+
+    // Initialize the arguments.
+    entry->addArguments(type.getInputs());
+  }
 }
 
 Function::~Function() {
+  // Instructions may have cyclic references, which need to be dropped before we
+  // can start deleting them.
+  for (auto &bb : *this)
+    for (auto &inst : bb)
+      inst.dropAllReferences();
+
   // Clean up function attributes referring to this function.
   FunctionAttr::dropFunctionReference(this);
 }
@@ -48,21 +65,6 @@ ArrayRef<NamedAttribute> Function::getAttrs() const {
 
 MLIRContext *Function::getContext() const { return getType().getContext(); }
 
-/// Delete this object.
-void Function::destroy() {
-  switch (getKind()) {
-  case Kind::ExtFunc:
-    delete cast<ExtFunction>(this);
-    break;
-  case Kind::MLFunc:
-    delete cast<MLFunction>(this);
-    break;
-  case Kind::CFGFunc:
-    delete cast<CFGFunction>(this);
-    break;
-  }
-}
-
 Module *llvm::ilist_traits<Function>::getContainingModule() {
   size_t Offset(
       size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
@@ -128,7 +130,7 @@ void llvm::ilist_traits<Function>::transferNodesFromList(
 }
 
 /// Unlink this function from its Module and delete it.
-void Function::eraseFromModule() {
+void Function::erase() {
   assert(getModule() && "Function has no parent");
   getModule()->getFunctions().erase(this);
 }
@@ -154,53 +156,11 @@ void Function::emitWarning(const Twine &message) const {
 bool Function::emitError(const Twine &message) const {
   return getContext()->emitError(getLoc(), message);
 }
-//===----------------------------------------------------------------------===//
-// ExtFunction implementation.
-//===----------------------------------------------------------------------===//
-
-ExtFunction::ExtFunction(Location location, StringRef name, FunctionType type,
-                         ArrayRef<NamedAttribute> attrs)
-    : Function(Kind::ExtFunc, location, name, type, attrs) {}
-
-//===----------------------------------------------------------------------===//
-// CFGFunction implementation.
-//===----------------------------------------------------------------------===//
-
-CFGFunction::CFGFunction(Location location, StringRef name, FunctionType type,
-                         ArrayRef<NamedAttribute> attrs)
-    : Function(Kind::CFGFunc, location, name, type, attrs) {}
-
-CFGFunction::~CFGFunction() {
-  // Instructions may have cyclic references, which need to be dropped before we
-  // can start deleting them.
-  for (auto &bb : *this)
-    for (auto &inst : bb)
-      inst.dropAllReferences();
-}
 
 //===----------------------------------------------------------------------===//
 // MLFunction implementation.
 //===----------------------------------------------------------------------===//
 
-MLFunction::MLFunction(Location location, StringRef name, FunctionType type,
-                       ArrayRef<NamedAttribute> attrs)
-    : Function(Kind::MLFunc, location, name, type, attrs), body(this) {
-
-  // The body of an MLFunction always has one block.
-  auto *entry = new StmtBlock();
-  body.push_back(entry);
-
-  // Initialize the arguments.
-  entry->addArguments(type.getInputs());
-}
-
-MLFunction::~MLFunction() {
-  // Explicitly erase statements instead of relying of 'StmtBlock' destructor
-  // since child statements need to be destroyed before function arguments
-  // are destroyed.
-  getBody()->clear();
-}
-
 const OperationStmt *MLFunction::getReturnStmt() const {
   return cast<OperationStmt>(&getBody()->back());
 }
diff --git a/mlir/lib/IR/Instructions.cpp b/mlir/lib/IR/Instructions.cpp
deleted file mode 100644 (file)
index 453b8b8..0000000
+++ /dev/null
@@ -1,330 +0,0 @@
-//===- Instructions.cpp - MLIR CFGFunction Instruction Classes ------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/Instructions.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLIRContext.h"
-using namespace mlir;
-
-/// Replace all uses of 'this' value with the new value, updating anything in
-/// the IR that uses 'this' to use the other value instead.  When this returns
-/// there are zero uses of 'this'.
-void IRObjectWithUseList::replaceAllUsesWith(IRObjectWithUseList *newValue) {
-  assert(this != newValue && "cannot RAUW a value with itself");
-  while (!use_empty()) {
-    use_begin()->set(newValue);
-  }
-}
-
-/// Return the result number of this result.
-unsigned InstResult::getResultNumber() const {
-  // Results are always stored consecutively, so use pointer subtraction to
-  // figure out what number this is.
-  return this - &getOwner()->getInstResults()[0];
-}
-
-/// Return which operand this is in the operand list.
-template <> unsigned InstOperand::getOperandNumber() const {
-  return this - &getOwner()->getInstOperands()[0];
-}
-
-/// Return which operand this is in the operand list.
-template <> unsigned BasicBlockOperand::getOperandNumber() const {
-  return this - &getOwner()->getBasicBlockOperands()[0];
-}
-
-//===----------------------------------------------------------------------===//
-// Instruction
-//===----------------------------------------------------------------------===//
-
-void Instruction::setSuccessor(BasicBlock *block, unsigned index) {
-  assert(index < getNumSuccessors());
-  getBasicBlockOperands()[index].set(block);
-}
-
-/// Create a new Instruction with the specified fields.
-Instruction *Instruction::create(Location location, OperationName name,
-                                 ArrayRef<CFGValue *> operands,
-                                 ArrayRef<Type> resultTypes,
-                                 ArrayRef<NamedAttribute> attributes,
-                                 ArrayRef<BasicBlock *> successors,
-                                 MLIRContext *context) {
-  unsigned numSuccessors = successors.size();
-  // Input operands are nullptr-separated for each successors in the case of
-  // terminators, the nullptr aren't actually stored.
-  unsigned numOperands = operands.size() - llvm::count(operands, nullptr);
-
-  auto byteSize =
-      totalSizeToAlloc<InstResult, InstOperand, BasicBlockOperand, unsigned>(
-          resultTypes.size(), numOperands, numSuccessors, numSuccessors);
-  void *rawMem = malloc(byteSize);
-
-  // Initialize the Instruction part of the instruction.
-  auto inst = ::new (rawMem)
-      Instruction(location, name, resultTypes.size(), numOperands,
-                  numSuccessors, attributes, context);
-
-  // Initialize the results and operands.
-  auto instResults = inst->getInstResults();
-  for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
-    new (&instResults[i]) InstResult(resultTypes[i], inst);
-
-  // instOperandIt is the iterator in the tail-allocated memory for the
-  // operands, and operandIt is the iterator in the input operands array.
-  // operandIt skips nullptr in the input that acts as sentinels marking the
-  // separation between multiple basic block successors for terminators.
-  auto instOperandIt = inst->getInstOperands().begin();
-  unsigned operandIt = 0, operandE = operands.size();
-  for (; operandIt != operandE; ++operandIt) {
-    if (!operands[operandIt])
-      break;
-    new (instOperandIt++) InstOperand(inst, operands[operandIt]);
-  }
-
-  // Check to see if a sentinel (nullptr) operand was encountered.
-  unsigned currentSuccNum = 0;
-  if (operandIt != operandE) {
-    assert(inst->isTerminator() &&
-           "Sentinel operand found in non terminator operand list.");
-    auto instBlockOperands = inst->getBasicBlockOperands();
-    unsigned *succOperandCountIt = inst->getTrailingObjects<unsigned>();
-    unsigned *succOperandCountE = succOperandCountIt + numSuccessors;
-    (void)succOperandCountE;
-
-    for (; operandIt != operandE; ++operandIt) {
-      // If we encounter a sentinal branch to the next operand update the count
-      // variable.
-      if (!operands[operandIt]) {
-        assert(currentSuccNum < numSuccessors);
-
-        // After the first iteration update the successor operand count
-        // variable.
-        if (currentSuccNum != 0) {
-          ++succOperandCountIt;
-          assert(succOperandCountIt != succOperandCountE &&
-                 "More sentinal operands than successors.");
-        }
-
-        new (&instBlockOperands[currentSuccNum])
-            BasicBlockOperand(inst, successors[currentSuccNum]);
-        *succOperandCountIt = 0;
-        ++currentSuccNum;
-        continue;
-      }
-      new (instOperandIt++) InstOperand(inst, operands[operandIt]);
-      ++(*succOperandCountIt);
-    }
-  }
-
-  // Verify that the amount of sentinal operands is equivalent to the number of
-  // successors.
-  assert(currentSuccNum == numSuccessors);
-  return inst;
-}
-
-Instruction *Instruction::clone() const {
-  SmallVector<CFGValue *, 8> operands;
-  SmallVector<Type, 8> resultTypes;
-  SmallVector<BasicBlock *, 1> successors;
-
-  // Put together the results.
-  for (auto *result : getResults())
-    resultTypes.push_back(result->getType());
-
-  // If the instruction is a terminator the successor and non-successor operand
-  // lists are interleaved with sentinal(nullptr) operands.
-  if (isTerminator()) {
-    // To interleave the operand lists we iterate in reverse and insert the
-    // operands in-place.
-    operands.resize(getNumOperands() + getNumSuccessors());
-    successors.resize(getNumSuccessors());
-    int cloneOperandIt = operands.size() - 1, operandIt = getNumOperands() - 1;
-    for (int succIt = getNumSuccessors() - 1, succE = 0; succIt >= succE;
-         --succIt) {
-      successors[succIt] = const_cast<BasicBlock *>(getSuccessor(succIt));
-
-      // Add the successor operands in-place in reverse order.
-      for (unsigned i = 0, e = getNumSuccessorOperands(succIt); i != e;
-           ++i, --cloneOperandIt, --operandIt) {
-        operands[cloneOperandIt] =
-            const_cast<CFGValue *>(getOperand(operandIt));
-      }
-
-      // Add a null operand for the barrier.
-      operands[cloneOperandIt--] = nullptr;
-    }
-
-    // Add the rest of the non-successor operands.
-    for (; cloneOperandIt >= 0; --cloneOperandIt, --operandIt)
-      operands[cloneOperandIt] = const_cast<CFGValue *>(getOperand(operandIt));
-    // For non terminators we can simply add each of the instructions in place.
-  } else {
-    for (auto *operand : getOperands())
-      operands.push_back(const_cast<CFGValue *>(operand));
-  }
-
-  return create(getLoc(), getName(), operands, resultTypes, getAttrs(),
-                successors, getContext());
-}
-
-Instruction::Instruction(Location location, OperationName name,
-                         unsigned numResults, unsigned numOperands,
-                         unsigned numSuccessors,
-                         ArrayRef<NamedAttribute> attributes,
-                         MLIRContext *context)
-    : Operation(/*isInstruction=*/true, name, attributes, context),
-      IROperandOwner(IROperandOwner::Kind::Instruction, location),
-      numResults(numResults), numSuccs(numSuccessors),
-      numOperands(numOperands) {}
-
-Instruction::~Instruction() {
-  assert(block == nullptr && "instruction destroyed but still in a block");
-
-  // Explicitly run the destructors for the results
-  for (auto &result : getInstResults())
-    result.~InstResult();
-
-  // Explicitly run the destructors for the operands.
-  for (auto &result : getInstOperands())
-    result.~InstOperand();
-
-  // Explicitly run the destructors for the successors.
-  if (isTerminator())
-    for (auto &successor : getBasicBlockOperands())
-      successor.~BasicBlockOperand();
-}
-
-void Instruction::eraseOperand(unsigned index) {
-  assert(index < getNumOperands());
-  auto Operands = getInstOperands();
-  // Shift all operands down by 1.
-  std::rotate(&Operands[index], &Operands[index + 1],
-              &Operands[numOperands - 1]);
-  --numOperands;
-  Operands[getNumOperands()].~InstOperand();
-}
-
-/// Destroy this instruction.
-void Instruction::destroy() {
-  this->~Instruction();
-  free(this);
-}
-
-CFGFunction *Instruction::getFunction() {
-  auto *block = getBlock();
-  return block ? block->getFunction() : nullptr;
-}
-
-/// This drops all operand uses from this instruction, which is an essential
-/// step in breaking cyclic dependences between references when they are to
-/// be deleted.
-void Instruction::dropAllReferences() {
-  for (auto &op : getInstOperands())
-    op.drop();
-
-  if (isTerminator())
-    for (auto &dest : getBasicBlockOperands())
-      dest.drop();
-}
-
-/// Emit a note about this instruction, reporting up to any diagnostic
-/// handlers that may be listening.
-void Instruction::emitNote(const Twine &message) const {
-  getContext()->emitDiagnostic(getLoc(), message,
-                               MLIRContext::DiagnosticKind::Note);
-}
-
-/// Emit a warning about this operation, reporting up to any diagnostic
-/// handlers that may be listening.
-void Instruction::emitWarning(const Twine &message) const {
-  getContext()->emitDiagnostic(getLoc(), message,
-                               MLIRContext::DiagnosticKind::Warning);
-}
-
-/// Emit an error about fatal conditions with this operation, reporting up to
-/// any diagnostic handlers that may be listening.  This function always
-/// returns true.  NOTE: This may terminate the containing application, only use
-/// when the IR is in an inconsistent state.
-bool Instruction::emitError(const Twine &message) const {
-  return getContext()->emitError(getLoc(), message);
-}
-
-void llvm::ilist_traits<::mlir::Instruction>::deleteNode(Instruction *inst) {
-  inst->destroy();
-}
-
-mlir::BasicBlock *
-llvm::ilist_traits<::mlir::Instruction>::getContainingBlock() {
-  size_t Offset(
-      size_t(&((BasicBlock *)nullptr->*BasicBlock::getSublistAccess(nullptr))));
-  iplist<Instruction> *Anchor(static_cast<iplist<Instruction> *>(this));
-  return reinterpret_cast<BasicBlock *>(reinterpret_cast<char *>(Anchor) -
-                                        Offset);
-}
-
-/// This is a trait method invoked when an instruction is added to a block.  We
-/// keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Instruction>::addNodeToList(Instruction *inst) {
-  assert(!inst->getBlock() && "already in a basic block!");
-  inst->block = getContainingBlock();
-}
-
-/// This is a trait method invoked when an instruction is removed from a block.
-/// We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Instruction>::removeNodeFromList(
-    Instruction *inst) {
-  assert(inst->block && "not already in a basic block!");
-  inst->block = nullptr;
-}
-
-/// This is a trait method invoked when an instruction is moved from one block
-/// to another.  We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Instruction>::transferNodesFromList(
-    ilist_traits<Instruction> &otherList, instr_iterator first,
-    instr_iterator last) {
-  // If we are transferring instructions within the same basic block, the block
-  // pointer doesn't need to be updated.
-  BasicBlock *curParent = getContainingBlock();
-  if (curParent == otherList.getContainingBlock())
-    return;
-
-  // Update the 'block' member of each instruction.
-  for (; first != last; ++first)
-    first->block = curParent;
-}
-
-/// Unlink this instruction from its BasicBlock and delete it.
-void Instruction::erase() {
-  assert(getBlock() && "Instruction has no parent");
-  getBlock()->getOperations().erase(this);
-}
-
-/// Unlink this operation instruction from its current basic block and insert
-/// it right before `existingInst` which may be in the same or another block
-/// in the same function.
-void Instruction::moveBefore(Instruction *existingInst) {
-  assert(existingInst && "Cannot move before a null instruction");
-  return moveBefore(existingInst->getBlock(), existingInst->getIterator());
-}
-
-/// Unlink this operation instruction from its current basic block and insert
-/// it right before `iterator` in the specified basic block.
-void Instruction::moveBefore(BasicBlock *block,
-                             llvm::iplist<Instruction>::iterator iterator) {
-  block->getOperations().splice(iterator, getBlock()->getOperations(),
-                                getIterator());
-}
index d50e7070f7089eab0ae9558a1c6806ed0a7db76b..3a537d03e8f9486ebbda782da9acdc87a00834c8 100644 (file)
 
 #include "mlir/IR/Operation.h"
 #include "AttributeListStorage.h"
-#include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/Dialect.h"
-#include "mlir/IR/Instructions.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
@@ -74,115 +72,77 @@ Operation::~Operation() {}
 
 /// Return the context this operation is associated with.
 MLIRContext *Operation::getContext() const {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->getContext();
   return llvm::cast<OperationStmt>(this)->getContext();
 }
 
 /// The source location the operation was defined or derived from.  Note that
 /// it is possible for this pointer to be null.
 Location Operation::getLoc() const {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->getLoc();
   return llvm::cast<OperationStmt>(this)->getLoc();
 }
 
 /// Set the source location the operation was defined or derived from.
 void Operation::setLoc(Location loc) {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    inst->setLoc(loc);
-  else
-    llvm::cast<OperationStmt>(this)->setLoc(loc);
+  llvm::cast<OperationStmt>(this)->setLoc(loc);
 }
 
 /// Return the function this operation is defined in.
 Function *Operation::getOperationFunction() {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->getFunction();
   return llvm::cast<OperationStmt>(this)->getFunction();
 }
 
 /// Return the number of operands this operation has.
 unsigned Operation::getNumOperands() const {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->getNumOperands();
-
   return llvm::cast<OperationStmt>(this)->getNumOperands();
 }
 
 SSAValue *Operation::getOperand(unsigned idx) {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->getOperand(idx);
-
   return llvm::cast<OperationStmt>(this)->getOperand(idx);
 }
 
 void Operation::setOperand(unsigned idx, SSAValue *value) {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this)) {
-    inst->setOperand(idx, llvm::cast<CFGValue>(value));
-  } else {
-    auto *stmt = llvm::cast<OperationStmt>(this);
-    stmt->setOperand(idx, llvm::cast<MLValue>(value));
-  }
+  auto *stmt = llvm::cast<OperationStmt>(this);
+  stmt->setOperand(idx, llvm::cast<MLValue>(value));
 }
 
 /// Return the number of results this operation has.
 unsigned Operation::getNumResults() const {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->getNumResults();
-
   return llvm::cast<OperationStmt>(this)->getNumResults();
 }
 
 /// Return the indicated result.
 SSAValue *Operation::getResult(unsigned idx) {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->getResult(idx);
-
   return llvm::cast<OperationStmt>(this)->getResult(idx);
 }
 
 unsigned Operation::getNumSuccessors() const {
   assert(isTerminator() && "Only terminators have successors.");
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->getNumSuccessors();
-
   return llvm::cast<OperationStmt>(this)->getNumSuccessors();
 }
 
 unsigned Operation::getNumSuccessorOperands(unsigned index) const {
   assert(isTerminator() && "Only terminators have successors.");
-
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->getNumSuccessorOperands(index);
-
   return llvm::cast<OperationStmt>(this)->getNumSuccessorOperands(index);
 }
 BasicBlock *Operation::getSuccessor(unsigned index) {
-  assert(isTerminator() && "Only terminators have successors.");
-  assert(llvm::isa<Instruction>(this) &&
-         "Only instructions have basic block successors.");
-  return llvm::cast<Instruction>(this)->getSuccessor(index);
+  assert(isTerminator() && "Only terminators have successors");
+  return llvm::cast<OperationStmt>(this)->getSuccessor(index);
 }
 void Operation::setSuccessor(BasicBlock *block, unsigned index) {
-  assert(isTerminator() && "Only terminators have successors.");
-  assert(llvm::isa<Instruction>(this) &&
-         "Only instructions have basic block successors.");
-  llvm::cast<Instruction>(this)->setSuccessor(block, index);
+  assert(isTerminator() && "Only terminators have successors");
+  llvm::cast<OperationStmt>(this)->setSuccessor(block, index);
 }
 
 void Operation::eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
-  assert(isTerminator() && "Only terminators have successors.");
-  assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
-  return llvm::cast<Instruction>(this)->eraseSuccessorOperand(succIndex,
-                                                              opIndex);
+  assert(isTerminator() && "Only terminators have successors");
+  return llvm::cast<OperationStmt>(this)->eraseSuccessorOperand(succIndex,
+                                                                opIndex);
 }
 auto Operation::getSuccessorOperands(unsigned index) const
     -> llvm::iterator_range<const_operand_iterator> {
   assert(isTerminator() && "Only terminators have successors.");
-  assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
   unsigned succOperandIndex =
-      llvm::cast<Instruction>(this)->getSuccessorOperandIndex(index);
+      llvm::cast<OperationStmt>(this)->getSuccessorOperandIndex(index);
   return {const_operand_iterator(this, succOperandIndex),
           const_operand_iterator(this, succOperandIndex +
                                            getNumSuccessorOperands(index))};
@@ -190,9 +150,8 @@ auto Operation::getSuccessorOperands(unsigned index) const
 auto Operation::getSuccessorOperands(unsigned index)
     -> llvm::iterator_range<operand_iterator> {
   assert(isTerminator() && "Only terminators have successors.");
-  assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
   unsigned succOperandIndex =
-      llvm::cast<Instruction>(this)->getSuccessorOperandIndex(index);
+      llvm::cast<OperationStmt>(this)->getSuccessorOperandIndex(index);
   return {operand_iterator(this, succOperandIndex),
           operand_iterator(this,
                            succOperandIndex + getNumSuccessorOperands(index))};
@@ -207,8 +166,6 @@ bool Operation::use_empty() const {
 }
 
 void Operation::moveBefore(Operation *existingOp) {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->moveBefore(llvm::cast<Instruction>(existingOp));
   return llvm::cast<OperationStmt>(this)->moveBefore(
       llvm::cast<OperationStmt>(existingOp));
 }
@@ -288,8 +245,6 @@ bool Operation::emitOpError(const Twine &message) const {
 
 /// Remove this operation from its parent block and delete it.
 void Operation::erase() {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->erase();
   return llvm::cast<OperationStmt>(this)->erase();
 }
 
@@ -319,14 +274,10 @@ bool Operation::constantFold(ArrayRef<Attribute> operands,
 }
 
 void Operation::print(raw_ostream &os) const {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->print(os);
   return llvm::cast<OperationStmt>(this)->print(os);
 }
 
 void Operation::dump() const {
-  if (auto *inst = llvm::dyn_cast<Instruction>(this))
-    return inst->dump();
   return llvm::cast<OperationStmt>(this)->dump();
 }
 
@@ -347,11 +298,10 @@ Operation *
 llvm::cast_convert_val<mlir::Operation, mlir::IROperandOwner *,
                        mlir::IROperandOwner *>::doit(const mlir::IROperandOwner
                                                          *value) {
+  // TODO(clattner): obsolete this.
   const Operation *op;
-  if (auto *ptr = dyn_cast<OperationStmt>(value))
-    op = ptr;
-  else
-    op = cast<Instruction>(value);
+  auto *ptr = cast<OperationStmt>(value);
+  op = ptr;
   return const_cast<Operation *>(op);
 }
 
@@ -579,12 +529,13 @@ static bool verifyTerminatorSuccessors(const Operation *op) {
 
 bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
   // Verify that the operation is at the end of the respective parent block.
-  if (auto *stmt = dyn_cast<OperationStmt>(op)) {
+  if (op->getOperationFunction()->isML()) {
+    auto *stmt = cast<OperationStmt>(op);
     StmtBlock *block = stmt->getBlock();
     if (!block || block->getContainingStmt() || &block->back() != stmt)
       return op->emitOpError("must be the last statement in the ML function");
   } else {
-    const Instruction *inst = cast<Instruction>(op);
+    auto *inst = cast<OperationInst>(op);
     const BasicBlock *block = inst->getBlock();
     if (!block || &block->back() != inst)
       return op->emitOpError(
index 32365d67f349f128cd3c21f4612330adb536fdc1..9a26149ea1dfcad38229fedbd6e35850103217c2 100644 (file)
 // =============================================================================
 
 #include "mlir/IR/SSAValue.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/Instructions.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/Statements.h"
 
 using namespace mlir;
 
 /// If this value is the result of an Instruction, return the instruction
 /// that defines it.
-Instruction *SSAValue::getDefiningInst() {
+OperationInst *SSAValue::getDefiningInst() {
   if (auto *result = dyn_cast<InstResult>(this))
     return result->getOwner();
   return nullptr;
@@ -50,10 +48,6 @@ Operation *SSAValue::getDefiningOperation() {
 /// Return the function that this SSAValue is defined in.
 Function *SSAValue::getFunction() {
   switch (getKind()) {
-  case SSAValueKind::BBArgument:
-    return cast<BBArgument>(this)->getFunction();
-  case SSAValueKind::InstResult:
-    return getDefiningInst()->getFunction();
   case SSAValueKind::BlockArgument:
     return cast<BlockArgument>(this)->getFunction();
   case SSAValueKind::StmtResult:
@@ -67,6 +61,16 @@ Function *SSAValue::getFunction() {
 // IROperandOwner implementation.
 //===----------------------------------------------------------------------===//
 
+/// Replace all uses of 'this' value with the new value, updating anything in
+/// the IR that uses 'this' to use the other value instead.  When this returns
+/// there are zero uses of 'this'.
+void IRObjectWithUseList::replaceAllUsesWith(IRObjectWithUseList *newValue) {
+  assert(this != newValue && "cannot RAUW a value with itself");
+  while (!use_empty()) {
+    use_begin()->set(newValue);
+  }
+}
+
 /// Return the context this operation is associated with.
 MLIRContext *IROperandOwner::getContext() const {
   switch (getKind()) {
@@ -85,26 +89,6 @@ MLIRContext *IROperandOwner::getContext() const {
   }
 }
 
-//===----------------------------------------------------------------------===//
-// CFGValue implementation.
-//===----------------------------------------------------------------------===//
-
-/// Return the function that this CFGValue is defined in.
-CFGFunction *CFGValue::getFunction() {
-  return cast<CFGFunction>(static_cast<SSAValue *>(this)->getFunction());
-}
-
-//===----------------------------------------------------------------------===//
-// BBArgument implementation.
-//===----------------------------------------------------------------------===//
-
-/// Return the function that this argument is defined in.
-CFGFunction *BBArgument::getFunction() {
-  if (auto *owner = getOwner())
-    return owner->getFunction();
-  return nullptr;
-}
-
 //===----------------------------------------------------------------------===//
 // MLValue implementation.
 //===----------------------------------------------------------------------===//
index 8922aaf72e0866971b7e4d59324574ddfbddc217..2a47eb56a285a30cadadd796af492ecdd89c825b 100644 (file)
@@ -18,8 +18,8 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/MLFunction.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/StmtVisitor.h"
@@ -177,6 +177,13 @@ bool Statement::emitError(const Twine &message) const {
   return getContext()->emitError(getLoc(), message);
 }
 
+// Returns whether the Statement is a terminator.
+bool Statement::isTerminator() const {
+  if (auto *op = dyn_cast<OperationStmt>(this))
+    return op->isTerminator();
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // ilist_traits for Statement
 //===----------------------------------------------------------------------===//
@@ -246,6 +253,18 @@ void Statement::moveBefore(StmtBlock *block,
                                 getIterator());
 }
 
+/// This drops all operand uses from this instruction, which is an essential
+/// step in breaking cyclic dependences between references when they are to
+/// be deleted.
+void Statement::dropAllReferences() {
+  for (auto &op : getStmtOperands())
+    op.drop();
+
+  if (isTerminator())
+    for (auto &dest : cast<OperationInst>(this)->getBlockOperands())
+      dest.drop();
+}
+
 //===----------------------------------------------------------------------===//
 // OperationStmt
 //===----------------------------------------------------------------------===//
@@ -258,14 +277,19 @@ OperationStmt *OperationStmt::create(Location location, OperationName name,
                                      ArrayRef<StmtBlock *> successors,
                                      MLIRContext *context) {
   unsigned numSuccessors = successors.size();
+
+  // Input operands are nullptr-separated for each successors in the case of
+  // terminators, the nullptr aren't actually stored.
+  unsigned numOperands = operands.size() - numSuccessors;
+
   auto byteSize =
       totalSizeToAlloc<StmtResult, StmtBlockOperand, unsigned, StmtOperand>(
-          resultTypes.size(), numSuccessors, numSuccessors, operands.size());
+          resultTypes.size(), numSuccessors, numSuccessors, numOperands);
   void *rawMem = malloc(byteSize);
 
   // Initialize the OperationStmt part of the statement.
   auto stmt = ::new (rawMem)
-      OperationStmt(location, name, operands.size(), resultTypes.size(),
+      OperationStmt(location, name, numOperands, resultTypes.size(),
                     numSuccessors, attributes, context);
 
   // Initialize the results and operands.
@@ -292,7 +316,6 @@ OperationStmt *OperationStmt::create(Location location, OperationName name,
     // Verify that the amount of sentinal operands is equivalent to the number
     // of successors.
     assert(currentSuccNum == numSuccessors);
-
     return stmt;
   }
 
@@ -350,6 +373,11 @@ OperationStmt::~OperationStmt() {
 
   for (auto &result : getStmtResults())
     result.~StmtResult();
+
+  // Explicitly run the destructors for the successors.
+  if (isTerminator())
+    for (auto &successor : getBlockOperands())
+      successor.~StmtBlockOperand();
 }
 
 void OperationStmt::destroy() {
@@ -373,6 +401,21 @@ MLIRContext *OperationStmt::getContext() const {
 
 bool OperationStmt::isReturn() const { return isa<ReturnOp>(); }
 
+void OperationStmt::setSuccessor(BasicBlock *block, unsigned index) {
+  assert(index < getNumSuccessors());
+  getBlockOperands()[index].set(block);
+}
+
+void OperationInst::eraseOperand(unsigned index) {
+  assert(index < getNumOperands());
+  auto Operands = getStmtOperands();
+  // Shift all operands down by 1.
+  std::rotate(&Operands[index], &Operands[index + 1],
+              &Operands[numOperands - 1]);
+  --numOperands;
+  Operands[getNumOperands()].~StmtOperand();
+}
+
 //===----------------------------------------------------------------------===//
 // ForStmt
 //===----------------------------------------------------------------------===//
@@ -671,3 +714,8 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
 
   return newIf;
 }
+
+Statement *Statement::clone(MLIRContext *context) const {
+  DenseMap<const MLValue *, MLValue *> operandMap;
+  return clone(operandMap, context);
+}
index 1c2c77d2da3b61e992a50a328a47ac60c82f9bc3..996375c302697522f76ee1042c606c60df0d8a95 100644 (file)
@@ -16,8 +16,8 @@
 // =============================================================================
 
 #include "mlir/IR/StmtBlock.h"
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
 using namespace mlir;
 
 StmtBlock::~StmtBlock() {
@@ -139,20 +139,63 @@ StmtBlock *StmtBlock::getSinglePredecessor() {
   return it == pred_end() ? firstPred : nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// Other
+//===----------------------------------------------------------------------===//
+
+/// Unlink this BasicBlock from its CFGFunction and delete it.
+void BasicBlock::eraseFromFunction() {
+  assert(getFunction() && "BasicBlock has no parent");
+  getFunction()->getBlocks().erase(this);
+}
+
+/// Split the basic block into two basic blocks before the specified
+/// instruction or iterator.
+///
+/// Note that all instructions BEFORE the specified iterator stay as part of
+/// the original basic block, an unconditional branch is added to the original
+/// block (going to the new block), and the rest of the instructions in the
+/// original block are moved to the new BB, including the old terminator.  The
+/// newly formed BasicBlock is returned.
+///
+/// This function invalidates the specified iterator.
+BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) {
+  // Start by creating a new basic block, and insert it immediate after this
+  // one in the containing function.
+  auto newBB = new BasicBlock();
+  getFunction()->getBlocks().insert(++CFGFunction::iterator(this), newBB);
+  auto branchLoc =
+      splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc();
+
+  // Move all of the operations from the split point to the end of the function
+  // into the new block.
+  newBB->getStatements().splice(newBB->end(), getStatements(), splitBefore,
+                                end());
+
+  // Create an unconditional branch to the new block, and move our terminator
+  // to the new block.
+  CFGFuncBuilder(this).create<BranchOp>(branchLoc, newBB);
+  return newBB;
+}
+
 //===----------------------------------------------------------------------===//
 // StmtBlockList
 //===----------------------------------------------------------------------===//
 
-StmtBlockList::StmtBlockList(MLFunction *container) : container(container) {}
+StmtBlockList::StmtBlockList(Function *container) : container(container) {}
 
 StmtBlockList::StmtBlockList(Statement *container) : container(container) {}
 
+CFGFunction *StmtBlockList::getFunction() {
+  return dyn_cast_or_null<CFGFunction>(getContainingFunction());
+}
+
 Statement *StmtBlockList::getContainingStmt() {
   return container.dyn_cast<Statement *>();
 }
 
-MLFunction *StmtBlockList::getContainingFunction() {
-  return container.dyn_cast<MLFunction *>();
+Function *StmtBlockList::getContainingFunction() {
+  return container.dyn_cast<Function *>();
 }
 
 StmtBlockList *llvm::ilist_traits<::mlir::StmtBlock>::getContainingBlockList() {
index 38ce5325eb8ae43d0b01d01d66213ce5e821b707..d58d687ee0c9571a4c06533565916436eff7bb07 100644 (file)
@@ -28,7 +28,6 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/MLFunction.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
@@ -68,7 +67,7 @@ public:
   ~ParserState() {
     // Destroy the forward references upon error.
     for (auto forwardRef : functionForwardRefs)
-      forwardRef.second->destroy();
+      delete forwardRef.second;
     functionForwardRefs.clear();
   }
 
@@ -785,8 +784,9 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
   if (!function) {
     auto &entry = state.functionForwardRefs[name];
     if (!entry)
-      entry = new ExtFunction(getEncodedSourceLocation(nameLoc), name, type,
-                              /*attrs=*/{});
+      entry = new Function(Function::Kind::ExtFunc,
+                           getEncodedSourceLocation(nameLoc), name, type,
+                           /*attrs=*/{});
     function = entry;
   }
 
@@ -1959,10 +1959,10 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
   // cannot be created through normal user input, allowing us to distinguish
   // them.
   auto name = OperationName("placeholder", getContext());
-  auto *inst = Instruction::create(getEncodedSourceLocation(loc), name,
-                                   /*operands=*/{}, type,
-                                   /*attributes=*/{},
-                                   /*successors=*/{}, getContext());
+  auto *inst = OperationInst::create(getEncodedSourceLocation(loc), name,
+                                     /*operands=*/{}, type,
+                                     /*attributes=*/{},
+                                     /*successors=*/{}, getContext());
   forwardReferencePlaceholders[inst->getResult(0)] = loc;
   return inst->getResult(0);
 }
@@ -3383,7 +3383,8 @@ ParseResult ModuleParser::parseExtFunc() {
 
   // Okay, the external function definition was parsed correctly.
   auto *function =
-      new ExtFunction(getEncodedSourceLocation(loc), name, type, attrs);
+      new Function(Function::Kind::ExtFunc, getEncodedSourceLocation(loc), name,
+                   type, attrs);
   getModule()->getFunctions().push_back(function);
 
   // Verify no name collision / redefinition.
@@ -3416,7 +3417,8 @@ ParseResult ModuleParser::parseCFGFunc() {
   // Okay, the CFG function signature was parsed correctly, create the
   // function.
   auto *function =
-      new CFGFunction(getEncodedSourceLocation(loc), name, type, attrs);
+      new Function(Function::Kind::CFGFunc, getEncodedSourceLocation(loc), name,
+                   type, attrs);
   getModule()->getFunctions().push_back(function);
 
   // Verify no name collision / redefinition.
@@ -3451,8 +3453,8 @@ ParseResult ModuleParser::parseMLFunc() {
 
   // Okay, the ML function signature was parsed correctly, create the
   // function.
-  auto *function =
-      new MLFunction(getEncodedSourceLocation(loc), name, type, attrs);
+  auto *function = new Function(
+      Function::Kind::MLFunc, getEncodedSourceLocation(loc), name, type, attrs);
   getModule()->getFunctions().push_back(function);
 
   // Verify no name collision / redefinition.
@@ -3504,7 +3506,7 @@ ParseResult ModuleParser::finalizeModule() {
   // Now that all references to the forward definition placeholders are
   // resolved, we can deallocate the placeholders.
   for (auto forwardRef : getState().functionForwardRefs)
-    forwardRef.second->destroy();
+    delete forwardRef.second;
   getState().functionForwardRefs.clear();
   return ParseSuccess;
 }
index 14615cd1e65db97a53d52778cb8487735a311d28..5c325dbd95d0e12402b14466fb8683cda7a5af02 100644 (file)
@@ -21,9 +21,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
+#include "mlir/IR/Statements.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/SuperVectorOps/SuperVectorOps.h"
 #include "mlir/Support/FileUtilities.h"
@@ -55,7 +55,7 @@ private:
   bool convertBasicBlock(const BasicBlock &bb, bool ignoreArguments = false);
   bool convertCFGFunction(const CFGFunction &cfgFunc, llvm::Function &llvmFunc);
   bool convertFunctions(const Module &mlirModule, llvm::Module &llvmModule);
-  bool convertInstruction(const Instruction &inst);
+  bool convertInstruction(const OperationInst &inst);
 
   void connectPHINodes(const CFGFunction &cfgFunc);
 
@@ -517,7 +517,7 @@ static llvm::CmpInst::Predicate getLLVMCmpPredicate(CmpIPredicate p) {
 // FIXME(zinenko): this should eventually become a separate MLIR pass that
 // converts MLIR standard operations into LLVM IR dialect; the translation in
 // that case would become a simple 1:1 instruction and value remapping.
-bool ModuleLowerer::convertInstruction(const Instruction &inst) {
+bool ModuleLowerer::convertInstruction(const OperationInst &inst) {
   if (auto op = inst.dyn_cast<AddIOp>())
     return valueMapping[op->getResult()] =
                builder.CreateAdd(valueMapping[op->getOperand(0)],
@@ -766,7 +766,11 @@ bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb,
 
   // Traverse instructions.
   for (const auto &inst : bb) {
-    if (convertInstruction(inst))
+    auto *op = dyn_cast<OperationInst>(&inst);
+    if (!op)
+      return inst.emitError("unsupported operation");
+
+    if (convertInstruction(*op))
       return true;
   }
 
@@ -779,7 +783,7 @@ static const SSAValue *getPHISourceValue(const BasicBlock *current,
                                          const BasicBlock *pred,
                                          unsigned numArguments,
                                          unsigned index) {
-  const Instruction &terminator = *pred->getTerminator();
+  auto &terminator = *pred->getTerminator();
   if (terminator.isa<BranchOp>()) {
     return terminator.getOperand(index);
   }
@@ -849,7 +853,7 @@ bool ModuleLowerer::convertFunctions(const Module &mlirModule,
   // call graph with cycles.  We don't expect MLFunctions here.
   for (const Function &function : mlirModule) {
     const Function *functionPtr = &function;
-    if (!isa<ExtFunction>(functionPtr) && !isa<CFGFunction>(functionPtr))
+    if (functionPtr->isML())
       continue;
     llvm::Constant *llvmFuncCst = llvmModule.getOrInsertFunction(
         function.getName(), convertFunctionType(function.getType()));
@@ -860,21 +864,20 @@ bool ModuleLowerer::convertFunctions(const Module &mlirModule,
   // Convert CFG functions.
   for (const Function &function : mlirModule) {
     const Function *functionPtr = &function;
-    auto cfgFunction = dyn_cast<CFGFunction>(functionPtr);
-    if (!cfgFunction)
+    if (!functionPtr->isCFG())
       continue;
-    llvm::Function *llvmFunc = functionMapping[cfgFunction];
+    llvm::Function *llvmFunc = functionMapping[functionPtr];
 
     // Add function arguments to the value remapping table.  In CFGFunction,
     // arguments of the first block are those of the function.
-    assert(!cfgFunction->getBlocks().empty() &&
+    assert(!functionPtr->getBlocks().empty() &&
            "expected at least one basic block in a CFGFunction");
-    const BasicBlock &firstBlock = *cfgFunction->begin();
+    const BasicBlock &firstBlock = *functionPtr->begin();
     for (auto arg : llvm::enumerate(llvmFunc->args())) {
       valueMapping[firstBlock.getArgument(arg.index())] = &arg.value();
     }
 
-    if (convertCFGFunction(*cfgFunction, *functionMapping[cfgFunction]))
+    if (convertCFGFunction(*functionPtr, *functionMapping[functionPtr]))
       return true;
   }
   return false;
index dd157066d96073c51b514438c521c8d00d6d6c3a..575ae2e1c9b69946ea8b7bd63eebc85d34f4b3b8 100644 (file)
@@ -23,7 +23,7 @@
 #include "mlir/Analysis/Dominance.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/Support/Functional.h"
@@ -201,7 +201,8 @@ struct CFGCSE : public CSEImpl {
 
   void simplifyBasicBlock(BasicBlock *bb) {
     for (auto &i : *bb)
-      simplifyOperation(&i);
+      if (auto *opInst = dyn_cast<OperationInst>(&i))
+        simplifyOperation(opInst);
   }
 };
 
index 15a5db15d739b83e4ff0bc2a9217a77c8387c763..d4a50a05989f278101c84681b102729d970594b0 100644 (file)
@@ -16,7 +16,7 @@
 // =============================================================================
 
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/Transforms/Passes.h"
@@ -110,20 +110,22 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
 
   for (auto &bb : *f) {
     for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
-      auto &inst = *instIt++;
+      auto *inst = dyn_cast<OperationInst>(&*instIt++);
+      if (!inst)
+        continue;
 
       auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
-        builder.setInsertionPoint(&inst);
-        return builder.create<ConstantOp>(inst.getLoc(), value, type);
+        builder.setInsertionPoint(inst);
+        return builder.create<ConstantOp>(inst->getLoc(), value, type);
       };
 
-      if (!foldOperation(&inst, existingConstants, constantFactory)) {
+      if (!foldOperation(inst, existingConstants, constantFactory)) {
         // At this point the operation is dead, remove it.
         // TODO: This is assuming that all constant foldable operations have no
         // side effects.  When we have side effect modeling, we should verify
         // that the operation is effect-free before we remove it.  Until then
         // this is close enough.
-        inst.erase();
+        inst->erase();
       }
     }
   }
index 4fafff5132299e4bd8d49133f719c94346e978f3..4423891a4bf289b559d5cc41d0cc360c0be78b6d 100644 (file)
@@ -21,8 +21,6 @@
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StmtVisitor.h"
@@ -552,8 +550,8 @@ PassResult ModuleConverter::runOnModule(Module *m) {
 
 void ModuleConverter::convertMLFunctions() {
   for (Function &fn : *module) {
-    if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
-      generatedFuncs[mlFunc] = convert(mlFunc);
+    if (fn.isML())
+      generatedFuncs[&fn] = convert(&fn);
   }
 }
 
@@ -562,8 +560,8 @@ CFGFunction *ModuleConverter::convert(MLFunction *mlFunc) {
   // Use the same name as for ML function; do not add the converted function to
   // the module yet to avoid collision.
   auto name = mlFunc->getName().str();
-  auto *cfgFunc = new CFGFunction(mlFunc->getLoc(), name, mlFunc->getType(),
-                                  mlFunc->getAttrs());
+  auto *cfgFunc = new Function(Function::Kind::CFGFunc, mlFunc->getLoc(), name,
+                               mlFunc->getType(), mlFunc->getAttrs());
 
   // Generates the body of the CFG function.
   return FunctionConverter(cfgFunc).convert(mlFunc);
@@ -580,14 +578,13 @@ void ModuleConverter::replaceReferences() {
   // functions.
   llvm::DenseMap<Attribute, FunctionAttr> remappingTable;
   for (const Function &fn : *module) {
-    const auto *mlFunc = dyn_cast<MLFunction>(&fn);
-    if (!mlFunc)
+    if (!fn.isML())
       continue;
-    CFGFunction *convertedFunc = generatedFuncs.lookup(mlFunc);
+    CFGFunction *convertedFunc = generatedFuncs.lookup(&fn);
     assert(convertedFunc && "ML function was not converted");
 
     MLIRContext *context = module->getContext();
-    auto mlFuncAttr = FunctionAttr::get(mlFunc, context);
+    auto mlFuncAttr = FunctionAttr::get(&fn, context);
     auto cfgFuncAttr = FunctionAttr::get(convertedFunc, module->getContext());
     remappingTable.insert({mlFuncAttr, cfgFuncAttr});
   }
@@ -607,12 +604,11 @@ void ModuleConverter::replaceReferences() {
 static inline void replaceMLFunctionAttr(
     Operation &op, Identifier name, const Function *func,
     const llvm::DenseMap<MLFunction *, CFGFunction *> &generatedFuncs) {
-  const auto *mlFunc = dyn_cast<MLFunction>(func);
-  if (!mlFunc)
+  if (!func->isML())
     return;
 
   Builder b(op.getContext());
-  auto cfgFunc = generatedFuncs.lookup(mlFunc);
+  auto *cfgFunc = generatedFuncs.lookup(func);
   op.setAttr(name, b.getFunctionAttr(cfgFunc));
 }
 
index 0dc405132e1e7883984ac3d6b7bbdf7914c817f5..e8a2af54b8e41a596f75ff64f86cd21096dfdaf6 100644 (file)
@@ -23,8 +23,6 @@
 //
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
 #include "mlir/Pass.h"
 #include "mlir/Transforms/LoweringUtils.h"
 #include "mlir/Transforms/Passes.h"
@@ -58,10 +56,11 @@ PassResult LowerAffineApply::runOnCFGFunction(CFGFunction *f) {
     // Handle iterators with care because we erase in the same loop.
     // In particular, step to the next element before erasing the current one.
     for (auto it = bb.begin(); it != bb.end();) {
-      Instruction &inst = *it;
-      OpPointer<AffineApplyOp> affineApplyOp = inst.dyn_cast<AffineApplyOp>();
-      ++it;
+      auto *inst = dyn_cast<OperationInst>(&*it++);
+      if (!inst)
+        continue;
 
+      auto affineApplyOp = inst->dyn_cast<AffineApplyOp>();
       if (!affineApplyOp)
         continue;
       if (expandAffineApply(&*affineApplyOp))
index a862ec4471a5b95cad7ade1a3f6b1147012e794e..048e26ae1150b0f3f222632f9a9688269a45214e 100644 (file)
@@ -20,7 +20,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/AffineStructures.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/Transforms/Passes.h"
index fbde1fd1692638ad70ed195438c4fb3e83fd00e4..0af7e52b5b11f973c38b19d9de0d880388c8b9a8 100644 (file)
@@ -172,9 +172,10 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
       // TODO: If we make terminators into Operations then we could turn this
       // into a nice Operation::moveBefore(Operation*) method.  We just need the
       // guarantee that a block is non-empty.
-      if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
-        auto &entryBB = cfgFunc->front();
-        cast<Instruction>(op)->moveBefore(&entryBB, entryBB.begin());
+      // TODO(clattner): This can all be simplified away now.
+      if (currentFunction->isCFG()) {
+        auto &entryBB = currentFunction->front();
+        cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
       } else {
         auto *mlFunc = cast<MLFunction>(currentFunction);
         cast<OperationStmt>(op)->moveBefore(mlFunc->getBody(),
@@ -315,7 +316,7 @@ static void processCFGFunction(CFGFunction *fn,
 
     void setInsertionPoint(Operation *op) override {
       // Any new operations should be added before this instruction.
-      builder.setInsertionPoint(cast<Instruction>(op));
+      builder.setInsertionPoint(cast<OperationInst>(op));
     }
 
   private:
@@ -325,7 +326,8 @@ static void processCFGFunction(CFGFunction *fn,
   GreedyPatternRewriteDriver driver(std::move(patterns));
   for (auto &bb : *fn)
     for (auto &op : bb)
-      driver.addToWorklist(&op);
+      if (auto *opInst = dyn_cast<OperationStmt>(&op))
+        driver.addToWorklist(opInst);
 
   CFGFuncBuilder cfgBuilder(fn);
   CFGFuncRewriter rewriter(driver, cfgBuilder);
@@ -337,9 +339,8 @@ static void processCFGFunction(CFGFunction *fn,
 ///
 void mlir::applyPatternsGreedily(Function *fn,
                                  OwningRewritePatternList &&patterns) {
-  if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
-    processCFGFunction(cfg, std::move(patterns));
-  } else {
-    processMLFunction(cast<MLFunction>(fn), std::move(patterns));
-  }
+  if (fn->isCFG())
+    processCFGFunction(fn, std::move(patterns));
+  else if (fn->isML())
+    processMLFunction(fn, std::move(patterns));
 }
index 8d375c42ca3c14cdb3eeb806b0939b29de67eb78..2818e8c2e4f86083ddfb02193f6c1ff201bcfacf 100644 (file)
@@ -438,18 +438,18 @@ void mlir::remapFunctionAttrs(
 void mlir::remapFunctionAttrs(
     Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
   // Look at all instructions in a CFGFunction.
-  if (auto *cfgFn = dyn_cast<CFGFunction>(&fn)) {
-    for (auto &bb : *cfgFn) {
+  if (fn.isCFG()) {
+    for (auto &bb : fn.getBlockList()) {
       for (auto &inst : bb) {
-        remapFunctionAttrs(inst, remappingTable);
+        if (auto *op = dyn_cast<OperationInst>(&inst))
+          remapFunctionAttrs(*op, remappingTable);
       }
     }
     return;
   }
 
-  // Otherwise, look at MLFunctions.  We ignore ExtFunctions.
-  auto *mlFn = dyn_cast<MLFunction>(&fn);
-  if (!mlFn)
+  // Otherwise, look at MLFunctions.  We ignore external functions.
+  if (!fn.isML())
     return;
 
   struct MLFnWalker : public StmtWalker<MLFnWalker> {
@@ -462,7 +462,7 @@ void mlir::remapFunctionAttrs(
     const DenseMap<Attribute, FunctionAttr> &remappingTable;
   };
 
-  MLFnWalker(remappingTable).walk(mlFn);
+  MLFnWalker(remappingTable).walk(&fn);
 }
 
 void mlir::remapFunctionAttrs(
index 07e1ebb15185bc47e69f6aa8e8865f5a26baa4b1..f255c932282324c45b24bbc6c42827aa526f4b9b 100644 (file)
@@ -23,9 +23,8 @@
 
 #include "mlir/Analysis/Passes.h"
 #include "mlir/IR/Attributes.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/MLFunction.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Parser.h"