/// Return true if instruction A dominates instruction B.
bool dominates(const SSAValue *a, const Instruction *b) {
- return a->getDefiningInst() == b || properlyDominates(a, b);
+ return (Instruction *)a->getDefiningInst() == b || properlyDominates(a, b);
}
// dominates/properlyDominates for basic blocks.
+++ /dev/null
-//===- BasicBlock.h - MLIR BasicBlock Class ---------------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef MLIR_IR_BASICBLOCK_H
-#define MLIR_IR_BASICBLOCK_H
-
-#include "mlir/IR/Instructions.h"
-
-namespace mlir {
-class BBArgument;
-class CFGFunction;
-template <typename BlockType> class PredecessorIterator;
-template <typename BlockType> class SuccessorIterator;
-
-/// Each basic block in a CFG function contains a list of basic block arguments,
-/// normal instructions, and a terminator instruction.
-///
-/// Basic blocks form a graph (the CFG) which can be traversed through
-/// predecessor and successor edges.
-class BasicBlock
- : public IRObjectWithUseList,
- public llvm::ilist_node_with_parent<BasicBlock, CFGFunction> {
-public:
- explicit BasicBlock();
- ~BasicBlock();
-
- /// Return the function that a BasicBlock is part of.
- CFGFunction *getFunction() { return function; }
- const CFGFunction *getFunction() const { return function; }
-
- /// Return the function that a BasicBlock is part of.
- const CFGFunction *getParent() const { return function; }
- CFGFunction *getParent() { return function; }
-
- //===--------------------------------------------------------------------===//
- // Block argument management
- //===--------------------------------------------------------------------===//
-
- // This is the list of arguments to the block.
- using BBArgListType = ArrayRef<BBArgument *>;
- BBArgListType getArguments() const { return arguments; }
-
- using args_iterator = BBArgListType::iterator;
- using reverse_args_iterator = BBArgListType::reverse_iterator;
- args_iterator args_begin() const { return getArguments().begin(); }
- args_iterator args_end() const { return getArguments().end(); }
- reverse_args_iterator args_rbegin() const { return getArguments().rbegin(); }
- reverse_args_iterator args_rend() const { return getArguments().rend(); }
-
- bool args_empty() const { return arguments.empty(); }
-
- /// Add one value to the argument list.
- BBArgument *addArgument(Type type);
-
- /// Add one argument to the argument list for each type specified in the list.
- llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
-
- /// Erase the argument at 'index' and remove it from the argument list.
- void eraseArgument(unsigned index);
-
- unsigned getNumArguments() const { return arguments.size(); }
- BBArgument *getArgument(unsigned i) { return arguments[i]; }
- const BBArgument *getArgument(unsigned i) const { return arguments[i]; }
-
- //===--------------------------------------------------------------------===//
- // Operation list management
- //===--------------------------------------------------------------------===//
-
- /// This is the list of operations in the block.
- using OperationListType = llvm::iplist<Instruction>;
- OperationListType &getOperations() { return operations; }
- const OperationListType &getOperations() const { return operations; }
-
- // Iteration over the operations in the block.
- using iterator = OperationListType::iterator;
- using const_iterator = OperationListType::const_iterator;
- using reverse_iterator = OperationListType::reverse_iterator;
- using const_reverse_iterator = OperationListType::const_reverse_iterator;
-
- iterator begin() { return operations.begin(); }
- iterator end() { return operations.end(); }
- const_iterator begin() const { return operations.begin(); }
- const_iterator end() const { return operations.end(); }
- reverse_iterator rbegin() { return operations.rbegin(); }
- reverse_iterator rend() { return operations.rend(); }
- const_reverse_iterator rbegin() const { return operations.rbegin(); }
- const_reverse_iterator rend() const { return operations.rend(); }
-
- bool empty() const { return operations.empty(); }
- void push_back(Instruction *inst) { operations.push_back(inst); }
- void push_front(Instruction *inst) { operations.push_front(inst); }
-
- Instruction &back() { return operations.back(); }
- const Instruction &back() const {
- return const_cast<BasicBlock *>(this)->back();
- }
-
- Instruction &front() { return operations.front(); }
- const Instruction &front() const {
- return const_cast<BasicBlock*>(this)->front();
- }
-
- //===--------------------------------------------------------------------===//
- // Terminator management
- //===--------------------------------------------------------------------===//
-
- /// Get the terminator instruction of this block, or null if the block is
- /// malformed.
- Instruction *getTerminator() const;
-
- //===--------------------------------------------------------------------===//
- // Predecessors and successors.
- //===--------------------------------------------------------------------===//
-
- // Predecessor iteration.
- using const_pred_iterator = PredecessorIterator<const BasicBlock>;
- const_pred_iterator pred_begin() const;
- const_pred_iterator pred_end() const;
- llvm::iterator_range<const_pred_iterator> getPredecessors() const;
-
- using pred_iterator = PredecessorIterator<BasicBlock>;
- pred_iterator pred_begin();
- pred_iterator pred_end();
- llvm::iterator_range<pred_iterator> getPredecessors();
-
- /// Return true if this block has no predecessors.
- bool hasNoPredecessors() const;
-
- /// If this basic block has exactly one predecessor, return it. Otherwise,
- /// return null.
- ///
- /// Note that if a block has duplicate predecessors from a single block (e.g.
- /// if you have a conditional branch with the same block as the true/false
- /// destinations) is not considered to be a single predecessor.
- BasicBlock *getSinglePredecessor();
-
- const BasicBlock *getSinglePredecessor() const {
- return const_cast<BasicBlock *>(this)->getSinglePredecessor();
- }
-
- // Indexed successor access.
- unsigned getNumSuccessors() const {
- return getTerminator()->getNumSuccessors();
- }
- const BasicBlock *getSuccessor(unsigned i) const {
- return const_cast<BasicBlock *>(this)->getSuccessor(i);
- }
- BasicBlock *getSuccessor(unsigned i) {
- return getTerminator()->getSuccessor(i);
- }
-
- // Successor iteration.
- using const_succ_iterator = SuccessorIterator<const BasicBlock>;
- const_succ_iterator succ_begin() const;
- const_succ_iterator succ_end() const;
- llvm::iterator_range<const_succ_iterator> getSuccessors() const;
-
- using succ_iterator = SuccessorIterator<BasicBlock>;
- succ_iterator succ_begin();
- succ_iterator succ_end();
- llvm::iterator_range<succ_iterator> getSuccessors();
-
- //===--------------------------------------------------------------------===//
- // Manipulators
- //===--------------------------------------------------------------------===//
-
- /// Unlink this BasicBlock from its CFGFunction and delete it.
- void eraseFromFunction();
-
- /// Split the basic block into two basic blocks before the specified
- /// instruction or iterator.
- ///
- /// Note that all instructions BEFORE the specified iterator stay as part of
- /// the original basic block, an unconditional branch is added to the original
- /// block (going to the new block), and the rest of the instructions in the
- /// original block are moved to the new BB, including the old terminator. The
- /// newly formed BasicBlock is returned.
- ///
- /// This function invalidates the specified iterator.
- BasicBlock *splitBasicBlock(iterator splitBefore);
- BasicBlock *splitBasicBlock(Instruction *splitBeforeInst) {
- return splitBasicBlock(iterator(splitBeforeInst));
- }
-
- void print(raw_ostream &os) const;
- void dump() const;
-
- /// Print out the name of the basic block without printing its body.
- /// NOTE: The printType argument is ignored. We keep it for compatibility
- /// with LLVM dominator machinery that expects it to exist.
- void printAsOperand(raw_ostream &os, bool printType = true);
-
- /// getSublistAccess() - Returns pointer to member of operation list
- static OperationListType BasicBlock::*getSublistAccess(Instruction *) {
- return &BasicBlock::operations;
- }
-
-private:
- CFGFunction *function = nullptr;
-
- /// This is the list of operations in the block.
- OperationListType operations;
-
- /// This is the list of arguments to the block.
- std::vector<BBArgument *> arguments;
-
- BasicBlock(const BasicBlock&) = delete;
- void operator=(const BasicBlock&) = delete;
-
- friend struct llvm::ilist_traits<BasicBlock>;
-};
-
-//===----------------------------------------------------------------------===//
-// Predecessors
-//===----------------------------------------------------------------------===//
-
-/// Implement a predecessor iterator as a forward iterator. This works by
-/// walking the use lists of basic blocks. The entries on this list are the
-/// BasicBlockOperands that are embedded into terminator instructions. From the
-/// operand, we can get the terminator that contains it, and it's parent block
-/// is the predecessor.
-template <typename BlockType>
-class PredecessorIterator
- : public llvm::iterator_facade_base<PredecessorIterator<BlockType>,
- std::forward_iterator_tag,
- BlockType *> {
-public:
- PredecessorIterator(BasicBlockOperand *firstOperand)
- : bbUseIterator(firstOperand) {}
-
- PredecessorIterator &operator=(const PredecessorIterator &rhs) {
- bbUseIterator = rhs.bbUseIterator;
- }
-
- bool operator==(const PredecessorIterator &rhs) const {
- return bbUseIterator == rhs.bbUseIterator;
- }
-
- BlockType *operator*() const {
- // The use iterator points to an operand of a terminator. The predecessor
- // we return is the basic block that that terminator is embedded into.
- return bbUseIterator.getUser()->getBlock();
- }
-
- PredecessorIterator &operator++() {
- ++bbUseIterator;
- return *this;
- }
-
- /// Get the successor number in the predecessor terminator.
- unsigned getSuccessorIndex() const {
- return bbUseIterator->getOperandNumber();
- }
-
-private:
- using BBUseIterator = SSAValueUseIterator<BasicBlockOperand, Instruction>;
- BBUseIterator bbUseIterator;
-};
-
-inline auto BasicBlock::pred_begin() const -> const_pred_iterator {
- return const_pred_iterator((BasicBlockOperand *)getFirstUse());
-}
-
-inline auto BasicBlock::pred_end() const -> const_pred_iterator {
- return const_pred_iterator(nullptr);
-}
-
-inline auto BasicBlock::getPredecessors() const
- -> llvm::iterator_range<const_pred_iterator> {
- return {pred_begin(), pred_end()};
-}
-
-inline auto BasicBlock::pred_begin() -> pred_iterator {
- return pred_iterator((BasicBlockOperand *)getFirstUse());
-}
-
-inline auto BasicBlock::pred_end() -> pred_iterator {
- return pred_iterator(nullptr);
-}
-
-inline auto BasicBlock::getPredecessors()
- -> llvm::iterator_range<pred_iterator> {
- return {pred_begin(), pred_end()};
-}
-
-//===----------------------------------------------------------------------===//
-// Successors
-//===----------------------------------------------------------------------===//
-
-/// This template implments the successor iterators for basic block.
-template <typename BlockType>
-class SuccessorIterator final
- : public IndexedAccessorIterator<SuccessorIterator<BlockType>, BlockType,
- BlockType> {
-public:
- /// Initializes the result iterator to the specified index.
- SuccessorIterator(BlockType *object, unsigned index)
- : IndexedAccessorIterator<SuccessorIterator<BlockType>, BlockType,
- BlockType>(object, index) {}
-
- SuccessorIterator(const SuccessorIterator &other)
- : SuccessorIterator(other.object, other.index) {}
-
- /// Support converting to the const variant. This will be a no-op for const
- /// variant.
- operator SuccessorIterator<const BlockType>() const {
- return SuccessorIterator<const BlockType>(this->object, this->index);
- }
-
- BlockType *operator*() const {
- return this->object->getSuccessor(this->index);
- }
-
- /// Get the successor number in the terminator.
- unsigned getSuccessorIndex() const { return this->index; }
-};
-
-inline auto BasicBlock::succ_begin() const -> const_succ_iterator {
- return const_succ_iterator(this, 0);
-}
-
-inline auto BasicBlock::succ_end() const -> const_succ_iterator {
- return const_succ_iterator(this, getNumSuccessors());
-}
-
-inline auto BasicBlock::getSuccessors() const
- -> llvm::iterator_range<const_succ_iterator> {
- return {succ_begin(), succ_end()};
-}
-
-inline auto BasicBlock::succ_begin() -> succ_iterator {
- return succ_iterator(this, 0);
-}
-
-inline auto BasicBlock::succ_end() -> succ_iterator {
- return succ_iterator(this, getNumSuccessors());
-}
-
-inline auto BasicBlock::getSuccessors() -> llvm::iterator_range<succ_iterator> {
- return {succ_begin(), succ_end()};
-}
-
-} // end namespace mlir
-
-//===----------------------------------------------------------------------===//
-// ilist_traits for BasicBlock
-//===----------------------------------------------------------------------===//
-
-namespace llvm {
-
-template <>
-struct ilist_traits<::mlir::BasicBlock>
- : public ilist_alloc_traits<::mlir::BasicBlock> {
- using BasicBlock = ::mlir::BasicBlock;
- using block_iterator = simple_ilist<BasicBlock>::iterator;
-
- void addNodeToList(BasicBlock *block);
- void removeNodeFromList(BasicBlock *block);
- void transferNodesFromList(ilist_traits<BasicBlock> &otherList,
- block_iterator first, block_iterator last);
-private:
- mlir::CFGFunction *getContainingFunction();
-};
-} // end namespace llvm
-
-
-#endif // MLIR_IR_BASICBLOCK_H
#ifndef MLIR_IR_BUILDERS_H
#define MLIR_IR_BUILDERS_H
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h"
namespace mlir {
}
void insert(Instruction *opInst) {
- block->getOperations().insert(insertPoint, opInst);
+ block->getStatements().insert(insertPoint, opInst);
}
/// Add new basic block and set the insertion point to the end of it. If an
BasicBlock *createBlock(BasicBlock *insertBefore = nullptr);
/// Create an operation given the fields represented as an OperationState.
- Instruction *createOperation(const OperationState &state);
+ OperationStmt *createOperation(const OperationState &state);
/// Create operation of specific op type at the current insertion point
/// without verifying to see if it is valid.
return OpPointer<OpTy>();
}
- Instruction *cloneOperation(const Instruction &srcOpInst) {
- auto *op = srcOpInst.clone();
+ OperationStmt *cloneOperation(const OperationStmt &srcOpInst) {
+ auto *op = cast<OperationStmt>(srcOpInst.clone(getContext()));
insert(op);
return op;
}
: Builder(mlFuncBuilder.getContext()), builder(mlFuncBuilder),
kind(Function::Kind::MLFunc) {}
FuncBuilder(Operation *op) : Builder(op->getContext()) {
- if (auto *inst = dyn_cast<Instruction>(op)) {
- builder = builderUnion(inst);
+ if (op->getOperationFunction()->isCFG()) {
+ builder = builderUnion(CFGFuncBuilder(cast<OperationInst>(op)));
kind = Function::Kind::CFGFunc;
} else {
- builder = builderUnion(cast<OperationStmt>(op));
+ builder = builderUnion(MLFuncBuilder(cast<OperationStmt>(op)));
kind = Function::Kind::MLFunc;
}
}
/// OperationStmt when building a ML function.
void setInsertionPoint(Operation *op) {
if (kind == Function::Kind::CFGFunc)
- builder.cfg.setInsertionPoint(cast<Instruction>(op));
+ builder.cfg.setInsertionPoint(cast<OperationStmt>(op));
else
builder.ml.setInsertionPoint(cast<OperationStmt>(op));
}
union builderUnion {
builderUnion(CFGFuncBuilder cfg) : cfg(cfg) {}
builderUnion(MLFuncBuilder ml) : ml(ml) {}
- builderUnion(Instruction *op) : cfg(op) {}
- builderUnion(OperationStmt *op) : ml(op) {}
// Default initializer to allow deferring initialization of member.
builderUnion() {}
#include "mlir/IR/OpDefinition.h"
namespace mlir {
-class BasicBlock;
class Builder;
class MLValue;
+++ /dev/null
-//===- CFGFunction.h - MLIR CFGFunction Class -------------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef MLIR_IR_CFGFUNCTION_H
-#define MLIR_IR_CFGFUNCTION_H
-
-#include "mlir/IR/BasicBlock.h"
-#include "mlir/IR/Function.h"
-
-namespace mlir {
-
-// This kind of function is defined in terms of a "Control Flow Graph" of basic
-// blocks, each of which includes instructions.
-class CFGFunction : public Function {
-public:
- CFGFunction(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs = {});
-
- ~CFGFunction();
-
- //===--------------------------------------------------------------------===//
- // BasicBlock list management
- //===--------------------------------------------------------------------===//
-
- /// This is the list of blocks in the function.
- using BasicBlockListType = llvm::iplist<BasicBlock>;
- BasicBlockListType &getBlocks() { return blocks; }
- const BasicBlockListType &getBlocks() const { return blocks; }
-
- // Iteration over the block in the function.
- using iterator = BasicBlockListType::iterator;
- using const_iterator = BasicBlockListType::const_iterator;
- using reverse_iterator = BasicBlockListType::reverse_iterator;
- using const_reverse_iterator = BasicBlockListType::const_reverse_iterator;
-
- iterator begin() { return blocks.begin(); }
- iterator end() { return blocks.end(); }
- const_iterator begin() const { return blocks.begin(); }
- const_iterator end() const { return blocks.end(); }
- reverse_iterator rbegin() { return blocks.rbegin(); }
- reverse_iterator rend() { return blocks.rend(); }
- const_reverse_iterator rbegin() const { return blocks.rbegin(); }
- const_reverse_iterator rend() const { return blocks.rend(); }
-
- bool empty() const { return blocks.empty(); }
- void push_back(BasicBlock *block) { blocks.push_back(block); }
- void push_front(BasicBlock *block) { blocks.push_front(block); }
-
- BasicBlock &back() { return blocks.back(); }
- const BasicBlock &back() const {
- return const_cast<CFGFunction *>(this)->back();
- }
-
- BasicBlock &front() { return blocks.front(); }
- const BasicBlock &front() const {
- return const_cast<CFGFunction*>(this)->front();
- }
-
- //===--------------------------------------------------------------------===//
- // Other
- //===--------------------------------------------------------------------===//
-
- /// getSublistAccess() - Returns pointer to member of block list
- static BasicBlockListType CFGFunction::*getSublistAccess(BasicBlock*) {
- return &CFGFunction::blocks;
- }
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(const Function *func) {
- return func->getKind() == Kind::CFGFunc;
- }
-
- /// Displays the CFG in a window. This is for use from the debugger and
- /// depends on Graphviz to generate the graph.
- /// This function is defined in CFGFunctionViewGraph and only works with that
- /// target linked.
- void viewGraph() const;
-
-private:
- BasicBlockListType blocks;
-};
-
-} // end namespace mlir
-
-#endif // MLIR_IR_CFGFUNCTION_H
+++ /dev/null
-//===- CFGValue.h - CFGValue base class and SSA type decls ------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines SSA manipulation implementations for CFG functions.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_CFGVALUE_H
-#define MLIR_IR_CFGVALUE_H
-
-#include "mlir/IR/SSAValue.h"
-
-namespace mlir {
-class BasicBlock;
-class CFGValue;
-class CFGFunction;
-class Instruction;
-
-/// This enum contains all of the SSA value kinds that are valid in a CFG
-/// function. This should be kept as a proper subtype of SSAValueKind,
-/// including having all of the values of the enumerators align.
-enum class CFGValueKind {
- BBArgument = (int)SSAValueKind::BBArgument,
- InstResult = (int)SSAValueKind::InstResult,
-};
-
-/// The operand of a CFG Instruction contains a CFGValue.
-using InstOperand = IROperandImpl<CFGValue, Instruction>;
-
-/// CFGValue is the base class for SSA values in CFG functions.
-class CFGValue : public SSAValueImpl<InstOperand, Instruction, CFGValueKind> {
-public:
- static bool classof(const SSAValue *value) {
- switch (value->getKind()) {
- case SSAValueKind::BBArgument:
- case SSAValueKind::InstResult:
- return true;
-
- case SSAValueKind::BlockArgument:
- case SSAValueKind::StmtResult:
- case SSAValueKind::ForStmt:
- return false;
- }
- }
-
- /// Return the function that this CFGValue is defined in.
- CFGFunction *getFunction();
-
- /// Return the function that this CFGValue is defined in.
- const CFGFunction *getFunction() const {
- return const_cast<CFGValue *>(this)->getFunction();
- }
-
-protected:
- CFGValue(CFGValueKind kind, Type type) : SSAValueImpl(kind, type) {}
-};
-
-/// Basic block arguments are CFG Values.
-class BBArgument : public CFGValue {
-public:
- static bool classof(const SSAValue *value) {
- return value->getKind() == SSAValueKind::BBArgument;
- }
-
- /// Return the function that this argument is defined in.
- CFGFunction *getFunction();
- const CFGFunction *getFunction() const {
- return const_cast<BBArgument *>(this)->getFunction();
- }
-
- BasicBlock *getOwner() { return owner; }
- const BasicBlock *getOwner() const { return owner; }
-
-private:
- friend class BasicBlock; // For access to private constructor.
- BBArgument(Type type, BasicBlock *owner)
- : CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
-
- /// The owner of this operand.
- /// TODO: can encode this more efficiently to avoid the space hit of this
- /// through bitpacking shenanigans.
- BasicBlock *const owner;
-};
-
-/// Instruction results are CFG Values.
-class InstResult : public CFGValue {
-public:
- InstResult(Type type, Instruction *owner)
- : CFGValue(CFGValueKind::InstResult, type), owner(owner) {}
-
- static bool classof(const SSAValue *value) {
- return value->getKind() == SSAValueKind::InstResult;
- }
-
- Instruction *getOwner() { return owner; }
- const Instruction *getOwner() const { return owner; }
-
- /// Return the number of this result.
- unsigned getResultNumber() const;
-
-private:
- /// The owner of this operand.
- /// TODO: can encode this more efficiently to avoid the space hit of this
- /// through bitpacking shenanigans.
- Instruction *const owner;
-};
-
-} // namespace mlir
-
-#endif
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Location.h"
+#include "mlir/IR/MLValue.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StmtBlock.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h"
class FunctionType;
class MLIRContext;
class Module;
+template <typename ObjectType, typename ElementType> class ArgumentIterator;
/// NamedAttribute is used for function attribute lists, it holds an
/// identifier for the name and a value for the attribute. The attribute
public:
enum class Kind { ExtFunc, CFGFunc, MLFunc };
+ Function(Kind kind, Location location, StringRef name, FunctionType type,
+ ArrayRef<NamedAttribute> attrs = {});
+ ~Function();
+
Kind getKind() const { return (Kind)nameAndKind.getInt(); }
+ bool isCFG() const { return getKind() == Kind::CFGFunc; }
+ bool isML() const { return getKind() == Kind::MLFunc; }
+
/// The source location the operation was defined or derived from.
Location getLoc() const { return location; }
Module *getModule() { return module; }
const Module *getModule() const { return module; }
- /// Unlink this instruction from its module and delete it.
- void eraseFromModule();
+ /// Unlink this function from its module and delete it.
+ void erase();
+
+ //===--------------------------------------------------------------------===//
+ // Body Handling
+ //===--------------------------------------------------------------------===//
+
+ StmtBlockList &getBlockList() { return blocks; }
+ const StmtBlockList &getBlockList() const { return blocks; }
+
+ /// This is the list of blocks in the function.
+ using BlockListType = llvm::iplist<BasicBlock>;
+ BlockListType &getBlocks() { return blocks.getBlocks(); }
+ const BlockListType &getBlocks() const { return blocks.getBlocks(); }
+
+ // Iteration over the block in the function.
+ using iterator = BlockListType::iterator;
+ using const_iterator = BlockListType::const_iterator;
+ using reverse_iterator = BlockListType::reverse_iterator;
+ using const_reverse_iterator = BlockListType::const_reverse_iterator;
+
+ iterator begin() { return blocks.begin(); }
+ iterator end() { return blocks.end(); }
+ const_iterator begin() const { return blocks.begin(); }
+ const_iterator end() const { return blocks.end(); }
+ reverse_iterator rbegin() { return blocks.rbegin(); }
+ reverse_iterator rend() { return blocks.rend(); }
+ const_reverse_iterator rbegin() const { return blocks.rbegin(); }
+ const_reverse_iterator rend() const { return blocks.rend(); }
+
+ bool empty() const { return blocks.empty(); }
+ void push_back(BasicBlock *block) { blocks.push_back(block); }
+ void push_front(BasicBlock *block) { blocks.push_front(block); }
+
+ BasicBlock &back() { return blocks.back(); }
+ const BasicBlock &back() const {
+ return const_cast<CFGFunction *>(this)->back();
+ }
+
+ BasicBlock &front() { return blocks.front(); }
+ const BasicBlock &front() const {
+ return const_cast<CFGFunction *>(this)->front();
+ }
+
+ /// Return the 'return' statement of this MLFunction.
+ const OperationStmt *getReturnStmt() const;
+ OperationStmt *getReturnStmt();
- /// Delete this object.
- void destroy();
+ // These should only be used on MLFunctions.
+ StmtBlock *getBody() {
+ assert(isML());
+ return &blocks.front();
+ }
+ const StmtBlock *getBody() const {
+ return const_cast<Function *>(this)->getBody();
+ }
+
+ /// Walk the statements in the function in preorder, calling the callback for
+ /// each Operation statement.
+ void walk(std::function<void(OperationStmt *)> callback);
+
+ /// Walk the statements in the function in postorder, calling the callback for
+ /// each Operation statement.
+ void walkPostOrder(std::function<void(OperationStmt *)> callback);
+
+ //===--------------------------------------------------------------------===//
+ // Arguments
+ //===--------------------------------------------------------------------===//
+
+ /// Returns number of arguments.
+ unsigned getNumArguments() const { return getType().getInputs().size(); }
+
+ /// Gets argument.
+ BlockArgument *getArgument(unsigned idx) {
+ return getBlocks().front().getArgument(idx);
+ }
+
+ const BlockArgument *getArgument(unsigned idx) const {
+ return getBlocks().front().getArgument(idx);
+ }
+
+ // Supports non-const operand iteration.
+ using args_iterator = ArgumentIterator<MLFunction, BlockArgument>;
+ args_iterator args_begin();
+ args_iterator args_end();
+ llvm::iterator_range<args_iterator> getArguments();
+
+ // Supports const operand iteration.
+ using const_args_iterator =
+ ArgumentIterator<const MLFunction, const BlockArgument>;
+ const_args_iterator args_begin() const;
+ const_args_iterator args_end() const;
+ llvm::iterator_range<const_args_iterator> getArguments() const;
+
+ //===--------------------------------------------------------------------===//
+ // Other
+ //===--------------------------------------------------------------------===//
/// Perform (potentially expensive) checks of invariants, used to detect
/// compiler bugs. On error, this reports the error through the MLIRContext
/// handlers that may be listening.
void emitNote(const Twine &message) const;
-protected:
- Function(Kind kind, Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs = {});
- ~Function();
+ /// Displays the CFG in a window. This is for use from the debugger and
+ /// depends on Graphviz to generate the graph.
+ /// This function is defined in CFGFunctionViewGraph and only works with that
+ /// target linked.
+ void viewGraph() const;
private:
/// The name of the function and the kind of function this is.
/// This holds general named attributes for the function.
AttributeListStorage *attrs;
+ /// The contents of the body.
+ StmtBlockList blocks;
+
void operator=(const Function &) = delete;
friend struct llvm::ilist_traits<Function>;
};
-/// An extfunc declaration is a declaration of a function signature that is
-/// defined in some other module.
-class ExtFunction : public Function {
+//===--------------------------------------------------------------------===//
+// ArgumentIterator
+//===--------------------------------------------------------------------===//
+
+/// This template implements the argument iterator in terms of getArgument(idx).
+template <typename ObjectType, typename ElementType>
+class ArgumentIterator final
+ : public IndexedAccessorIterator<ArgumentIterator<ObjectType, ElementType>,
+ ObjectType, ElementType> {
public:
- ExtFunction(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs = {});
+ /// Initializes the result iterator to the specified index.
+ ArgumentIterator(ObjectType *object, unsigned index)
+ : IndexedAccessorIterator<ArgumentIterator<ObjectType, ElementType>,
+ ObjectType, ElementType>(object, index) {}
+
+ /// Support converting to the const variant. This will be a no-op for const
+ /// variant.
+ operator ArgumentIterator<const ObjectType, const ElementType>() const {
+ return ArgumentIterator<const ObjectType, const ElementType>(this->object,
+ this->index);
+ }
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(const Function *func) {
- return func->getKind() == Kind::ExtFunc;
+ ElementType *operator*() const {
+ return this->object->getArgument(this->index);
}
};
+//===--------------------------------------------------------------------===//
+// MLFunction iterator methods.
+//===--------------------------------------------------------------------===//
+
+inline MLFunction::args_iterator MLFunction::args_begin() {
+ return args_iterator(this, 0);
+}
+
+inline MLFunction::args_iterator MLFunction::args_end() {
+ return args_iterator(this, getNumArguments());
+}
+
+inline llvm::iterator_range<MLFunction::args_iterator>
+MLFunction::getArguments() {
+ return {args_begin(), args_end()};
+}
+
+inline MLFunction::const_args_iterator MLFunction::args_begin() const {
+ return const_args_iterator(this, 0);
+}
+
+inline MLFunction::const_args_iterator MLFunction::args_end() const {
+ return const_args_iterator(this, getNumArguments());
+}
+
+inline llvm::iterator_range<MLFunction::const_args_iterator>
+MLFunction::getArguments() const {
+ return {args_begin(), args_end()};
+}
+
} // end namespace mlir
//===----------------------------------------------------------------------===//
using Function = ::mlir::Function;
using function_iterator = simple_ilist<Function>::iterator;
- static void deleteNode(Function *function) { function->destroy(); }
+ static void deleteNode(Function *function) { delete function; }
void addNodeToList(Function *function);
void removeNodeFromList(Function *function);
#ifndef MLIR_IR_CFGFUNCTIONGRAPHTRAITS_H
#define MLIR_IR_CFGFUNCTIONGRAPHTRAITS_H
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
#include "llvm/ADT/GraphTraits.h"
namespace llvm {
}
};
-template <> struct GraphTraits<mlir::StmtBlock *> {
- using ChildIteratorType = mlir::StmtBlock::succ_iterator;
- using Node = mlir::StmtBlock;
- using NodeRef = Node *;
-
- static NodeRef getEntryNode(NodeRef bb) { return bb; }
-
- static ChildIteratorType child_begin(NodeRef node) {
- return node->succ_begin();
- }
- static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); }
-};
-
-template <> struct GraphTraits<const mlir::StmtBlock *> {
- using ChildIteratorType = mlir::StmtBlock::const_succ_iterator;
- using Node = const mlir::StmtBlock;
- using NodeRef = Node *;
-
- static NodeRef getEntryNode(NodeRef bb) { return bb; }
-
- static ChildIteratorType child_begin(NodeRef node) {
- return node->succ_begin();
- }
- static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); }
-};
-
-template <> struct GraphTraits<Inverse<mlir::StmtBlock *>> {
- using ChildIteratorType = mlir::StmtBlock::pred_iterator;
- using Node = mlir::StmtBlock;
- using NodeRef = Node *;
- static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) {
- return inverseGraph.Graph;
- }
- static inline ChildIteratorType child_begin(NodeRef node) {
- return node->pred_begin();
- }
- static inline ChildIteratorType child_end(NodeRef node) {
- return node->pred_end();
- }
-};
-
-template <> struct GraphTraits<Inverse<const mlir::StmtBlock *>> {
- using ChildIteratorType = mlir::StmtBlock::const_pred_iterator;
- using Node = const mlir::StmtBlock;
- using NodeRef = Node *;
-
- static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) {
- return inverseGraph.Graph;
- }
- static inline ChildIteratorType child_begin(NodeRef node) {
- return node->pred_begin();
- }
- static inline ChildIteratorType child_end(NodeRef node) {
- return node->pred_end();
- }
-};
-
template <>
-struct GraphTraits<mlir::MLFunction *> : public GraphTraits<mlir::StmtBlock *> {
- using GraphType = mlir::MLFunction *;
- using NodeRef = mlir::StmtBlock *;
+struct GraphTraits<mlir::StmtBlockList *>
+ : public GraphTraits<mlir::BasicBlock *> {
+ using GraphType = mlir::StmtBlockList *;
+ using NodeRef = mlir::BasicBlock *;
- static NodeRef getEntryNode(GraphType fn) {
- return &fn->getBlockList().front();
- }
+ static NodeRef getEntryNode(GraphType fn) { return &fn->front(); }
- using nodes_iterator = pointer_iterator<mlir::StmtBlockList::iterator>;
+ using nodes_iterator = pointer_iterator<mlir::CFGFunction::iterator>;
static nodes_iterator nodes_begin(GraphType fn) {
- return nodes_iterator(fn->getBlockList().begin());
+ return nodes_iterator(fn->begin());
}
static nodes_iterator nodes_end(GraphType fn) {
- return nodes_iterator(fn->getBlockList().end());
+ return nodes_iterator(fn->end());
}
};
template <>
-struct GraphTraits<const mlir::MLFunction *>
- : public GraphTraits<const mlir::StmtBlock *> {
- using GraphType = const mlir::MLFunction *;
- using NodeRef = const mlir::StmtBlock *;
+struct GraphTraits<const mlir::StmtBlockList *>
+ : public GraphTraits<const mlir::BasicBlock *> {
+ using GraphType = const mlir::StmtBlockList *;
+ using NodeRef = const mlir::BasicBlock *;
- static NodeRef getEntryNode(GraphType fn) {
- return &fn->getBlockList().front();
- }
+ static NodeRef getEntryNode(GraphType fn) { return &fn->front(); }
- using nodes_iterator = pointer_iterator<mlir::StmtBlockList::const_iterator>;
+ using nodes_iterator = pointer_iterator<mlir::CFGFunction::const_iterator>;
static nodes_iterator nodes_begin(GraphType fn) {
- return nodes_iterator(fn->getBlockList().begin());
+ return nodes_iterator(fn->begin());
}
static nodes_iterator nodes_end(GraphType fn) {
- return nodes_iterator(fn->getBlockList().end());
+ return nodes_iterator(fn->end());
}
};
template <>
-struct GraphTraits<Inverse<mlir::MLFunction *>>
- : public GraphTraits<Inverse<mlir::StmtBlock *>> {
- using GraphType = Inverse<mlir::MLFunction *>;
+struct GraphTraits<Inverse<mlir::StmtBlockList *>>
+ : public GraphTraits<Inverse<mlir::BasicBlock *>> {
+ using GraphType = Inverse<mlir::StmtBlockList *>;
using NodeRef = NodeRef;
- static NodeRef getEntryNode(GraphType fn) {
- return &fn.Graph->getBlockList().front();
- }
+ static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); }
- using nodes_iterator = pointer_iterator<mlir::StmtBlockList::iterator>;
+ using nodes_iterator = pointer_iterator<mlir::CFGFunction::iterator>;
static nodes_iterator nodes_begin(GraphType fn) {
- return nodes_iterator(fn.Graph->getBlockList().begin());
+ return nodes_iterator(fn.Graph->begin());
}
static nodes_iterator nodes_end(GraphType fn) {
- return nodes_iterator(fn.Graph->getBlockList().end());
+ return nodes_iterator(fn.Graph->end());
}
};
template <>
-struct GraphTraits<Inverse<const mlir::MLFunction *>>
- : public GraphTraits<Inverse<const mlir::StmtBlock *>> {
- using GraphType = Inverse<const mlir::MLFunction *>;
+struct GraphTraits<Inverse<const mlir::StmtBlockList *>>
+ : public GraphTraits<Inverse<const mlir::BasicBlock *>> {
+ using GraphType = Inverse<const mlir::StmtBlockList *>;
using NodeRef = NodeRef;
- static NodeRef getEntryNode(GraphType fn) {
- return &fn.Graph->getBlockList().front();
- }
+ static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); }
- using nodes_iterator = pointer_iterator<mlir::StmtBlockList::const_iterator>;
+ using nodes_iterator = pointer_iterator<mlir::CFGFunction::const_iterator>;
static nodes_iterator nodes_begin(GraphType fn) {
- return nodes_iterator(fn.Graph->getBlockList().begin());
+ return nodes_iterator(fn.Graph->begin());
}
static nodes_iterator nodes_end(GraphType fn) {
- return nodes_iterator(fn.Graph->getBlockList().end());
+ return nodes_iterator(fn.Graph->end());
}
};
+++ /dev/null
-//===- Instructions.h - MLIR CFG Instruction Classes ------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines the classes for CFGFunction instructions.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_INSTRUCTIONS_H
-#define MLIR_IR_INSTRUCTIONS_H
-
-#include "mlir/IR/CFGValue.h"
-#include "mlir/IR/Identifier.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/ilist.h"
-#include "llvm/Support/TrailingObjects.h"
-
-namespace mlir {
-class BasicBlock;
-class CFGFunction;
-} // end namespace mlir
-
-//===----------------------------------------------------------------------===//
-// ilist_traits for Instruction
-//===----------------------------------------------------------------------===//
-
-namespace llvm {
-
-template <> struct ilist_traits<::mlir::Instruction> {
- using Instruction = ::mlir::Instruction;
- using instr_iterator = simple_ilist<Instruction>::iterator;
-
- static void deleteNode(Instruction *inst);
- void addNodeToList(Instruction *inst);
- void removeNodeFromList(Instruction *inst);
- void transferNodesFromList(ilist_traits<Instruction> &otherList,
- instr_iterator first, instr_iterator last);
-
-private:
- mlir::BasicBlock *getContainingBlock();
-};
-
-} // end namespace llvm
-
-namespace mlir {
-
-/// The operand of a Terminator contains a BasicBlock.
-using BasicBlockOperand = IROperandImpl<BasicBlock, Instruction>;
-
-// The trailing objects of an instruction are layed out as follows:
-// - InstResult : The results of the instruction.
-// - InstOperand : The operands of the instruction. For terminators, the
-// operands for all successors are concatenated here.
-// - BasicBlockOperand : Use-list of successor blocks if this is a terminator.
-// - unsigned : Count of operands held for each of the successors.
-//
-// Note: For Terminators, we rely on the assumption that all non-successor
-// operands are placed at the beginning of the operands list.
-class Instruction final
- : public Operation,
- public IROperandOwner,
- public llvm::ilist_node_with_parent<Instruction, BasicBlock>,
- private llvm::TrailingObjects<Instruction, InstResult, InstOperand,
- BasicBlockOperand, unsigned> {
-public:
- using IROperandOwner::getContext;
- using IROperandOwner::getLoc;
- using IROperandOwner::setLoc;
-
- //===--------------------------------------------------------------------===//
- // Operands
- //===--------------------------------------------------------------------===//
-
- unsigned getNumOperands() const { return numOperands; }
-
- CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
- const CFGValue *getOperand(unsigned idx) const {
- return getInstOperand(idx).get();
- }
- void setOperand(unsigned idx, CFGValue *value) {
- getInstOperand(idx).set(value);
- }
-
- // Support non-const operand iteration.
- using operand_iterator = OperandIterator<Instruction, CFGValue>;
-
- operand_iterator operand_begin() { return operand_iterator(this, 0); }
-
- operand_iterator operand_end() {
- return operand_iterator(this, getNumOperands());
- }
-
- llvm::iterator_range<operand_iterator> getOperands() {
- return {operand_begin(), operand_end()};
- }
-
- // Support const operand iteration.
- using const_operand_iterator =
- OperandIterator<const Instruction, const CFGValue>;
-
- const_operand_iterator operand_begin() const {
- return const_operand_iterator(this, 0);
- }
-
- const_operand_iterator operand_end() const {
- return const_operand_iterator(this, getNumOperands());
- }
-
- llvm::iterator_range<const_operand_iterator> getOperands() const {
- return {operand_begin(), operand_end()};
- }
-
- ArrayRef<InstOperand> getInstOperands() const {
- return {getTrailingObjects<InstOperand>(), numOperands};
- }
-
- MutableArrayRef<InstOperand> getInstOperands() {
- return {getTrailingObjects<InstOperand>(), numOperands};
- }
-
- // Accessors to InstOperand.
- InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }
- const InstOperand &getInstOperand(unsigned idx) const {
- return getInstOperands()[idx];
- }
-
- //===--------------------------------------------------------------------===//
- // Results
- //===--------------------------------------------------------------------===//
-
- unsigned getNumResults() const { return numResults; }
-
- CFGValue *getResult(unsigned idx) { return &getInstResult(idx); }
- const CFGValue *getResult(unsigned idx) const { return &getInstResult(idx); }
-
- // Support non-const result iteration.
- using result_iterator = ResultIterator<Instruction, CFGValue>;
- result_iterator result_begin() { return result_iterator(this, 0); }
- result_iterator result_end() {
- return result_iterator(this, getNumResults());
- }
- llvm::iterator_range<result_iterator> getResults() {
- return {result_begin(), result_end()};
- }
-
- // Support const result iteration.
- using const_result_iterator =
- ResultIterator<const Instruction, const CFGValue>;
- const_result_iterator result_begin() const {
- return const_result_iterator(this, 0);
- }
-
- const_result_iterator result_end() const {
- return const_result_iterator(this, getNumResults());
- }
-
- llvm::iterator_range<const_result_iterator> getResults() const {
- return {result_begin(), result_end()};
- }
-
- ArrayRef<InstResult> getInstResults() const {
- return {getTrailingObjects<InstResult>(), numResults};
- }
-
- MutableArrayRef<InstResult> getInstResults() {
- return {getTrailingObjects<InstResult>(), numResults};
- }
-
- InstResult &getInstResult(unsigned idx) { return getInstResults()[idx]; }
-
- const InstResult &getInstResult(unsigned idx) const {
- return getInstResults()[idx];
- }
-
- // Support result type iteration.
- using result_type_iterator =
- ResultTypeIterator<const Instruction, const CFGValue>;
- result_type_iterator result_type_begin() const {
- return result_type_iterator(this, 0);
- }
-
- result_type_iterator result_type_end() const {
- return result_type_iterator(this, getNumResults());
- }
-
- llvm::iterator_range<result_type_iterator> getResultTypes() const {
- return {result_type_begin(), result_type_end()};
- }
-
- //===--------------------------------------------------------------------===//
- // Terminators
- //===--------------------------------------------------------------------===//
-
- MutableArrayRef<BasicBlockOperand> getBasicBlockOperands() {
- assert(isTerminator() && "Only terminators have a block operands list.");
- return {getTrailingObjects<BasicBlockOperand>(), numSuccs};
- }
- ArrayRef<BasicBlockOperand> getBasicBlockOperands() const {
- return const_cast<Instruction *>(this)->getBasicBlockOperands();
- }
-
- MutableArrayRef<InstOperand> getSuccessorInstOperands(unsigned index) {
- assert(isTerminator() && "Only terminators have successors.");
- assert(index < getNumSuccessors());
- unsigned succOpIndex = getSuccessorOperandIndex(index);
- auto *operandBegin = getInstOperands().data() + succOpIndex;
- return {operandBegin, getNumSuccessorOperands(index)};
- }
- ArrayRef<InstOperand> getSuccessorInstOperands(unsigned index) const {
- return const_cast<Instruction *>(this)->getSuccessorInstOperands(index);
- }
-
- unsigned getNumSuccessors() const { return getBasicBlockOperands().size(); }
- unsigned getNumSuccessorOperands(unsigned index) const {
- assert(isTerminator() && "Only terminators have successors.");
- assert(index < getNumSuccessors());
- return getTrailingObjects<unsigned>()[index];
- }
-
- BasicBlock *getSuccessor(unsigned index) {
- assert(index < getNumSuccessors());
- return getBasicBlockOperands()[index].get();
- }
- const BasicBlock *getSuccessor(unsigned index) const {
- return const_cast<Instruction *>(this)->getSuccessor(index);
- }
- void setSuccessor(BasicBlock *block, unsigned index);
-
- /// Erase a specific operand from the operand list of the successor at
- /// 'index'.
- void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
- assert(succIndex < getNumSuccessors());
- assert(opIndex < getNumSuccessorOperands(succIndex));
- eraseOperand(getSuccessorOperandIndex(succIndex) + opIndex);
- --getTrailingObjects<unsigned>()[succIndex];
- }
-
- /// Get the index of the first operand of the successor at the provided
- /// index.
- unsigned getSuccessorOperandIndex(unsigned index) const {
- assert(isTerminator() && "Only terminators have successors.");
- assert(index < getNumSuccessors());
-
- // Count the number of operands for each of the successors after, and
- // including, the one at 'index'. This is based upon the assumption that all
- // non successor operands are placed at the beginning of the operand list.
- auto *successorOpCountBegin = getTrailingObjects<unsigned>();
- unsigned postSuccessorOpCount =
- std::accumulate(successorOpCountBegin + index,
- successorOpCountBegin + getNumSuccessors(), 0);
- return getNumOperands() - postSuccessorOpCount;
- }
-
- //===--------------------------------------------------------------------===//
- // Other
- //===--------------------------------------------------------------------===//
-
- /// Create a new Instruction with the specified fields.
- static Instruction *
- create(Location location, OperationName name, ArrayRef<CFGValue *> operands,
- ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes,
- ArrayRef<BasicBlock *> successors, MLIRContext *context);
-
- Instruction *clone() const;
-
- /// Return the BasicBlock containing this instruction.
- const BasicBlock *getBlock() const { return block; }
- BasicBlock *getBlock() { return block; }
-
- /// Return the CFGFunction containing this instruction.
- CFGFunction *getFunction();
- const CFGFunction *getFunction() const {
- return const_cast<Instruction *>(this)->getFunction();
- }
-
- void print(raw_ostream &os) const;
- void dump() const;
-
- /// Unlink this instruction from its BasicBlock and delete it.
- void erase();
-
- /// Destroy this instruction and its subclass data.
- void destroy();
-
- /// Unlink this instruction from its current basic block and insert
- /// it right before `existingInst` which may be in the same or another block
- /// of the same function.
- void moveBefore(Instruction *existingInst);
-
- /// Unlink this instruction from its current basic block and insert
- /// it right before `iterator` in the specified basic block, which must be in
- /// the same function.
- void moveBefore(BasicBlock *block,
- llvm::iplist<Instruction>::iterator iterator);
-
- /// This drops all operand uses from this instruction, which is an essential
- /// step in breaking cyclic dependences between references when they are to
- /// be deleted.
- void dropAllReferences();
-
- /// Emit an error about fatal conditions with this operation, reporting up to
- /// any diagnostic handlers that may be listening. This function always
- /// returns true. NOTE: This may terminate the containing application, only
- /// use when the IR is in an inconsistent state.
- bool emitError(const Twine &message) const;
-
- /// Emit a warning about this operation, reporting up to any diagnostic
- /// handlers that may be listening.
- void emitWarning(const Twine &message) const;
-
- /// Emit a note about this operation, reporting up to any diagnostic
- /// handlers that may be listening.
- void emitNote(const Twine &message) const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(const IROperandOwner *ptr) {
- return ptr->getKind() == IROperandOwner::Kind::Instruction;
- }
- static bool classof(const Operation *op) {
- return op->getOperationKind() == OperationKind::Instruction;
- }
-
-private:
- const unsigned numResults, numSuccs;
- /// The number of operands tail-allocated, mutable to support erasing.
- unsigned numOperands;
- BasicBlock *block = nullptr;
-
- Instruction(Location location, OperationName name, unsigned numResults,
- unsigned numOperands, unsigned numSuccessors,
- ArrayRef<NamedAttribute> attributes, MLIRContext *context);
-
- // Instructions are deleted through the destroy() member because this class
- // does not have a virtual destructor.
- ~Instruction();
-
- /// Erase the operand at 'index'.
- void eraseOperand(unsigned index);
-
- friend struct llvm::ilist_traits<Instruction>;
- friend class BasicBlock;
-
- // This stuff is used by the TrailingObjects template.
- friend llvm::TrailingObjects<Instruction, InstResult, InstOperand,
- BasicBlockOperand, unsigned>;
- size_t numTrailingObjects(OverloadToken<InstResult>) const {
- return numResults;
- }
- size_t numTrailingObjects(OverloadToken<InstOperand>) const {
- return numOperands;
- }
- size_t numTrailingObjects(OverloadToken<BasicBlockOperand>) const {
- return numSuccs;
- }
- size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; }
-};
-
-inline raw_ostream &operator<<(raw_ostream &os, const Instruction &inst) {
- inst.print(os);
- return os;
-}
-
-} // end namespace mlir
-
-#endif // MLIR_IR_INSTRUCTIONS_H
+++ /dev/null
-//===- MLFunction.h - MLIR MLFunction Class ---------------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines MLFunction class
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_MLFUNCTION_H_
-#define MLIR_IR_MLFUNCTION_H_
-
-#include "mlir/IR/Function.h"
-#include "mlir/IR/MLValue.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/StmtBlock.h"
-#include "llvm/Support/TrailingObjects.h"
-
-namespace mlir {
-
-template <typename ObjectType, typename ElementType> class ArgumentIterator;
-
-// MLFunction is defined as a sequence of statements that may
-// include nested affine for loops, conditionals and operations.
-class MLFunction final : public Function {
-public:
- MLFunction(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs = {});
-
- // TODO(clattner): drop this, it is redundant.
- static MLFunction *create(Location location, StringRef name,
- FunctionType type,
- ArrayRef<NamedAttribute> attrs = {}) {
- return new MLFunction(location, name, type, attrs);
- }
-
- StmtBlockList &getBlockList() { return body; }
- const StmtBlockList &getBlockList() const { return body; }
-
- StmtBlock *getBody() { return &body.front(); }
- const StmtBlock *getBody() const { return &body.front(); }
-
- //===--------------------------------------------------------------------===//
- // Arguments
- //===--------------------------------------------------------------------===//
-
- /// Returns number of arguments.
- unsigned getNumArguments() const { return getType().getInputs().size(); }
-
- /// Gets argument.
- BlockArgument *getArgument(unsigned idx) {
- return getBlockList().front().getArgument(idx);
- }
-
- const BlockArgument *getArgument(unsigned idx) const {
- return getBlockList().front().getArgument(idx);
- }
-
- // Supports non-const operand iteration.
- using args_iterator = ArgumentIterator<MLFunction, BlockArgument>;
- args_iterator args_begin();
- args_iterator args_end();
- llvm::iterator_range<args_iterator> getArguments();
-
- // Supports const operand iteration.
- using const_args_iterator =
- ArgumentIterator<const MLFunction, const BlockArgument>;
- const_args_iterator args_begin() const;
- const_args_iterator args_end() const;
- llvm::iterator_range<const_args_iterator> getArguments() const;
-
- //===--------------------------------------------------------------------===//
- // Other
- //===--------------------------------------------------------------------===//
-
- ~MLFunction();
-
- // Return the 'return' statement of this MLFunction.
- const OperationStmt *getReturnStmt() const;
- OperationStmt *getReturnStmt();
-
- /// Walk the statements in the function in preorder, calling the callback for
- /// each Operation statement.
- void walk(std::function<void(OperationStmt *)> callback);
-
- /// Walk the statements in the function in postorder, calling the callback for
- /// each Operation statement.
- void walkPostOrder(std::function<void(OperationStmt *)> callback);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(const Function *func) {
- return func->getKind() == Function::Kind::MLFunc;
- }
-
-private:
- StmtBlockList body;
-};
-
-//===--------------------------------------------------------------------===//
-// ArgumentIterator
-//===--------------------------------------------------------------------===//
-
-/// This template implements the argument iterator in terms of getArgument(idx).
-template <typename ObjectType, typename ElementType>
-class ArgumentIterator final
- : public IndexedAccessorIterator<ArgumentIterator<ObjectType, ElementType>,
- ObjectType, ElementType> {
-public:
- /// Initializes the result iterator to the specified index.
- ArgumentIterator(ObjectType *object, unsigned index)
- : IndexedAccessorIterator<ArgumentIterator<ObjectType, ElementType>,
- ObjectType, ElementType>(object, index) {}
-
- /// Support converting to the const variant. This will be a no-op for const
- /// variant.
- operator ArgumentIterator<const ObjectType, const ElementType>() const {
- return ArgumentIterator<const ObjectType, const ElementType>(this->object,
- this->index);
- }
-
- ElementType *operator*() const {
- return this->object->getArgument(this->index);
- }
-};
-
-//===--------------------------------------------------------------------===//
-// MLFunction iterator methods.
-//===--------------------------------------------------------------------===//
-
-inline MLFunction::args_iterator MLFunction::args_begin() {
- return args_iterator(this, 0);
-}
-
-inline MLFunction::args_iterator MLFunction::args_end() {
- return args_iterator(this, getNumArguments());
-}
-
-inline llvm::iterator_range<MLFunction::args_iterator>
-MLFunction::getArguments() {
- return {args_begin(), args_end()};
-}
-
-inline MLFunction::const_args_iterator MLFunction::args_begin() const {
- return const_args_iterator(this, 0);
-}
-
-inline MLFunction::const_args_iterator MLFunction::args_end() const {
- return const_args_iterator(this, getNumArguments());
-}
-
-inline llvm::iterator_range<MLFunction::const_args_iterator>
-MLFunction::getArguments() const {
- return {args_begin(), args_end()};
-}
-
-} // end namespace mlir
-
-#endif // MLIR_IR_MLFUNCTION_H_
namespace mlir {
class ForStmt;
class MLValue;
-class MLFunction;
+using MLFunction = Function;
class Statement;
class StmtBlock;
case SSAValueKind::StmtResult:
case SSAValueKind::ForStmt:
return true;
-
- case SSAValueKind::BBArgument:
- case SSAValueKind::InstResult:
- return false;
}
}
OperationStmt *const owner;
};
+// TODO(clattner) clean all this up.
+using CFGValue = MLValue;
+using BBArgument = BlockArgument;
+using InstResult = StmtResult;
+
} // namespace mlir
#endif
#include <type_traits>
namespace mlir {
-class BasicBlock;
class Builder;
namespace OpTrait {
namespace mlir {
class AttributeListStorage;
-class BasicBlock;
template <typename OpType> class ConstOpPointer;
template <typename OpType> class OpPointer;
template <typename ObjectType, typename ElementType> class OperandIterator;
template <typename ObjectType, typename ElementType> class ResultTypeIterator;
class Function;
class IROperandOwner;
-class Instruction;
class Statement;
+class OperationStmt;
+using Instruction = Statement;
/// Operations represent all of the arithmetic and other basic computation in
/// MLIR. This class is the common implementation details behind Instruction
#include <memory>
namespace mlir {
-class BasicBlock;
class Dialect;
class Operation;
class OperationState;
class StmtBlock;
class SSAValue;
class Type;
+using BasicBlock = StmtBlock;
/// This is a vector that owns the patterns inside of it.
using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
SmallVector<Type, 4> types;
SmallVector<NamedAttribute, 4> attributes;
/// Successors of this operation and their respective operands.
- SmallVector<BasicBlock *, 1> successors;
-
- // TODO: rename to successors when CFG and ML Functions are merged.
- SmallVector<StmtBlock *, 1> successorsS;
+ SmallVector<StmtBlock *, 1> successors;
public:
OperationState(MLIRContext *context, Location location, StringRef name)
OperationState(MLIRContext *context, Location location, OperationName name)
: context(context), location(location), name(name) {}
- OperationState(MLIRContext *context, Location location, StringRef name,
- ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
- ArrayRef<NamedAttribute> attributes = {},
- ArrayRef<BasicBlock *> successors = {})
- : context(context), location(location), name(name, context),
- operands(operands.begin(), operands.end()),
- types(types.begin(), types.end()),
- attributes(attributes.begin(), attributes.end()),
- successors(successors.begin(), successors.end()) {}
-
OperationState(MLIRContext *context, Location location, StringRef name,
ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
ArrayRef<NamedAttribute> attributes,
- ArrayRef<StmtBlock *> successors)
+ ArrayRef<StmtBlock *> successors = {})
: context(context), location(location), name(name, context),
operands(operands.begin(), operands.end()),
types(types.begin(), types.end()),
attributes(attributes.begin(), attributes.end()),
- successorsS(successors.begin(), successors.end()) {}
+ successors(successors.begin(), successors.end()) {}
void addOperands(ArrayRef<SSAValue *> newOperands) {
assert(successors.empty() &&
attributes.push_back({name, attr});
}
- void addSuccessor(BasicBlock *successor, ArrayRef<SSAValue *> succOperands) {
+ void addSuccessor(StmtBlock *successor, ArrayRef<SSAValue *> succOperands) {
successors.push_back(successor);
// Insert a sentinal operand to mark a barrier between successor operands.
operands.push_back(nullptr);
namespace mlir {
class Function;
-class Instruction;
class OperationStmt;
class Operation;
+class Statement;
+using Instruction = Statement;
+using OperationInst = OperationStmt;
/// This enumerates all of the SSA value kinds in the MLIR system.
enum class SSAValueKind {
- BBArgument, // basic block argument
- InstResult, // instruction result
BlockArgument, // Block argument
StmtResult, // statement result
ForStmt, // for statement induction variable
/// If this value is the result of an Instruction, return the instruction
/// that defines it.
- Instruction *getDefiningInst();
- const Instruction *getDefiningInst() const {
+ OperationInst *getDefiningInst();
+ const OperationInst *getDefiningInst() const {
return const_cast<SSAValue *>(this)->getDefiningInst();
}
namespace mlir {
class Location;
-class MLFunction;
+using MLFunction = Function;
class StmtBlock;
class ForStmt;
class MLIRContext;
/// sub-statements to the corresponding statement that is copied, and adds
/// those mappings to the map.
Statement *clone(OperandMapTy &operandMap, MLIRContext *context) const;
+ Statement *clone(MLIRContext *context) const;
/// Returns the statement block that contains this statement.
StmtBlock *getBlock() const { return block; }
/// Destroys this statement and its subclass data.
void destroy();
+ /// This drops all operand uses from this instruction, which is an essential
+ /// step in breaking cyclic dependences between references when they are to
+ /// be deleted.
+ void dropAllReferences();
+
/// Unlink this statement from its current block and insert it right before
/// `existingStmt` which may be in the same or another block in the same
/// function.
/// it right before `iterator` in the specified basic block.
void moveBefore(StmtBlock *block, llvm::iplist<Statement>::iterator iterator);
+ // Returns whether the Statement is a terminator.
+ bool isTerminator() const;
+
void print(raw_ostream &os) const;
void dump() const;
class IntegerSet;
class AffineCondition;
class OperationStmt;
+using OperationInst = OperationStmt;
/// Operation statements represent operations inside ML functions.
class OperationStmt final
/// Return the context this operation is associated with.
MLIRContext *getContext() const;
+ using Operation::isTerminator;
using Statement::dump;
using Statement::emitError;
using Statement::emitNote;
}
void setSuccessor(BasicBlock *block, unsigned index);
+ /// Erase a specific operand from the operand list of the successor at
+ /// 'index'.
+ void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
+ assert(succIndex < getNumSuccessors());
+ assert(opIndex < getNumSuccessorOperands(succIndex));
+ eraseOperand(getSuccessorOperandIndex(succIndex) + opIndex);
+ --getTrailingObjects<unsigned>()[succIndex];
+ }
+
/// Get the index of the first operand of the successor at the provided
/// index.
unsigned getSuccessorOperandIndex(unsigned index) const {
}
private:
- const unsigned numOperands, numResults, numSuccs;
+ unsigned numOperands;
+ const unsigned numResults, numSuccs;
OperationStmt(Location location, OperationName name, unsigned numOperands,
unsigned numResults, unsigned numSuccessors,
ArrayRef<NamedAttribute> attributes, MLIRContext *context);
~OperationStmt();
+ /// Erase the operand at 'index'.
+ void eraseOperand(unsigned index);
+
// This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<OperationStmt, StmtResult, StmtBlockOperand,
unsigned, StmtOperand>;
#include "mlir/IR/Statement.h"
namespace mlir {
-class MLFunction;
class IfStmt;
class MLValue;
class StmtBlockList;
+using CFGFunction = Function;
+using MLFunction = Function;
// TODO(clattner): drop the Stmt prefixes on these once BasicBlock's versions of
// these go away.
succ_iterator succ_end();
llvm::iterator_range<succ_iterator> getSuccessors();
+ //===--------------------------------------------------------------------===//
+ // Other
+ //===--------------------------------------------------------------------===//
+
+ /// Unlink this BasicBlock from its CFGFunction and delete it.
+ void eraseFromFunction();
+
+ /// Split the basic block into two basic blocks before the specified
+ /// instruction or iterator.
+ ///
+ /// Note that all instructions BEFORE the specified iterator stay as part of
+ /// the original basic block, an unconditional branch is added to the original
+ /// block (going to the new block), and the rest of the instructions in the
+ /// original block are moved to the new BB, including the old terminator. The
+ /// newly formed BasicBlock is returned.
+ ///
+ /// This function invalidates the specified iterator.
+ BasicBlock *splitBasicBlock(iterator splitBefore);
+ BasicBlock *splitBasicBlock(Instruction *splitBeforeInst) {
+ return splitBasicBlock(iterator(splitBeforeInst));
+ }
+
/// getSublistAccess() - Returns pointer to member of statement list
static StmtListType StmtBlock::*getSublistAccess(Statement *) {
return &StmtBlock::statements;
}
- /// These have unconventional names to avoid derive class ambiguities.
- void printBlock(raw_ostream &os) const;
- void dumpBlock() const;
+ void print(raw_ostream &os) const;
+ void dump() const;
+
+ /// Print out the name of the basic block without printing its body.
+ /// NOTE: The printType argument is ignored. We keep it for compatibility
+ /// with LLVM dominator machinery that expects it to exist.
+ void printAsOperand(raw_ostream &os, bool printType = true);
private:
/// This is the parent function/IfStmt/ForStmt that owns this block.
/// is part of - an MLFunction or IfStmt or ForStmt.
class StmtBlockList {
public:
- explicit StmtBlockList(MLFunction *container);
+ explicit StmtBlockList(Function *container);
explicit StmtBlockList(Statement *container);
using BlockListType = llvm::iplist<StmtBlock>;
return &StmtBlockList::blocks;
}
- /// A StmtBlockList is part of a MLFunction or and IfStmt/ForStmt. If it is
+ /// A StmtBlockList is part of a Function or and IfStmt/ForStmt. If it is
/// part of an IfStmt/ForStmt, then return it, otherwise return null.
Statement *getContainingStmt();
const Statement *getContainingStmt() const {
return const_cast<StmtBlockList *>(this)->getContainingStmt();
}
- /// A StmtBlockList is part of a MLFunction or and IfStmt/ForStmt. If it is
- /// part of an MLFunction, then return it, otherwise return null.
- MLFunction *getContainingFunction();
- const MLFunction *getContainingFunction() const {
+ /// A StmtBlockList is part of a Function or and IfStmt/ForStmt. If it is
+ /// part of an Function, then return it, otherwise return null.
+ Function *getContainingFunction();
+ const Function *getContainingFunction() const {
return const_cast<StmtBlockList *>(this)->getContainingFunction();
}
+ // TODO(clattner): This is only to help ML -> CFG migration, remove in the
+ // near future. This makes StmtBlockList work more like BasicBlock did.
+ CFGFunction *getFunction();
+ const CFGFunction *getFunction() const {
+ return const_cast<StmtBlockList *>(this)->getFunction();
+ }
+
private:
BlockListType blocks;
/// This is the object we are part of.
- llvm::PointerUnion<MLFunction *, Statement *> container;
+ llvm::PointerUnion<Function *, Statement *> container;
};
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_STMTVISITOR_H
#define MLIR_IR_STMTVISITOR_H
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h"
namespace mlir {
namespace mlir {
class Function;
-class CFGFunction;
-class MLFunction;
+using CFGFunction = Function;
+using MLFunction = Function;
class Module;
// Values that can be used by to signal success/failure. This can be implicitly
#ifndef MLIR_TRANSFORMS_CFGFUNCTIONVIEWGRAPH_H_
#define MLIR_TRANSFORMS_CFGFUNCTIONVIEWGRAPH_H_
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/Pass.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/GraphWriter.h"
class AffineMap;
class ForStmt;
-class MLFunction;
+class Function;
+using MLFunction = Function;
class MLFuncBuilder;
// Values that can be used to signal success/failure. This can be implicitly
namespace mlir {
-class CFGFunction;
class ForStmt;
class FuncBuilder;
class Location;
class Module;
class OperationStmt;
class SSAValue;
+class Function;
+using CFGFunction = Function;
/// Replace all uses of oldMemRef with newMemRef while optionally remapping the
/// old memref's indices using the supplied affine map and adding any additional
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Dominance.h"
+#include "mlir/IR/Statements.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
using namespace mlir;
/// Compute the immediate-dominators map.
DominanceInfo::DominanceInfo(CFGFunction *function) : DominatorTreeBase() {
// Build the dominator tree for the function.
- recalculate(*function);
+ recalculate(function->getBlockList());
}
/// Return true if instruction A properly dominates instruction B.
// limitations under the License.
// =============================================================================
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
PassResult PrintOpStatsPass::runOnCFGFunction(CFGFunction *function) {
for (const auto &bb : *function)
for (const auto &inst : bb)
- ++opCount[inst.getName().getStringRef()];
+ if (auto *op = dyn_cast<OperationInst>(&inst))
+ ++opCount[op->getName().getStringRef()];
return success();
}
//===----------------------------------------------------------------------===//
#include "mlir/Pass.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Support/PassNameParser.h"
#include "llvm/ADT/DenseMap.h"
}
PassResult FunctionPass::runOnFunction(Function *fn) {
- if (auto *mlFunc = dyn_cast<MLFunction>(fn))
- return runOnMLFunction(mlFunc);
- if (auto *cfgFunc = dyn_cast<CFGFunction>(fn))
- return runOnCFGFunction(cfgFunc);
+ if (fn->isML())
+ return runOnMLFunction(fn);
+ if (fn->isCFG())
+ return runOnCFGFunction(fn);
return success();
}
#include "mlir/Analysis/Dominance.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
bool failure(const Twine &message, const BasicBlock &bb) {
// Take the location information for the first instruction in the block.
if (!bb.empty())
- return failure(message, bb.front());
+ if (auto *op = dyn_cast<OperationStmt>(&bb.front()))
+ return failure(message, *op);
// Worst case, fall back to using the function's location.
return failure(message, fn);
}
for (auto &inst : block) {
- if (verifyOperation(inst) || verifyInstOperands(inst))
+ if (auto *opInst = dyn_cast<OperationInst>(&inst))
+ if (verifyOperation(*opInst))
+ return true;
+
+ if (verifyInstOperands(inst))
return true;
}
return false;
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Statements.h"
// Visit functions.
void visitFunction(const Function *fn);
- void visitExtFunction(const ExtFunction *fn);
+ void visitExtFunction(const Function *fn);
void visitCFGFunction(const CFGFunction *fn);
void visitMLFunction(const MLFunction *fn);
void visitStatement(const Statement *stmt);
visitAttribute(elt.second);
}
-void ModuleState::visitExtFunction(const ExtFunction *fn) {
+void ModuleState::visitExtFunction(const Function *fn) {
visitType(fn->getType());
}
void ModuleState::visitCFGFunction(const CFGFunction *fn) {
visitType(fn->getType());
for (auto &block : *fn) {
- for (auto &op : block.getOperations()) {
- visitOperation(&op);
+ for (auto &op : block.getStatements()) {
+ if (auto *opInst = dyn_cast<OperationInst>(&op))
+ visitOperation(opInst);
+ else {
+ llvm_unreachable("IfStmt/ForStmt in a CFGFunction isn't supported");
+ }
}
}
}
void ModuleState::visitFunction(const Function *fn) {
switch (fn->getKind()) {
case Function::Kind::ExtFunc:
- return visitExtFunction(cast<ExtFunction>(fn));
+ return visitExtFunction(fn);
case Function::Kind::CFGFunc:
- return visitCFGFunction(cast<CFGFunction>(fn));
+ return visitCFGFunction(fn);
case Function::Kind::MLFunc:
- return visitMLFunction(cast<MLFunction>(fn));
+ return visitMLFunction(fn);
}
}
void printAttribute(Attribute attr);
void printType(Type type);
void print(const Function *fn);
- void print(const ExtFunction *fn);
- void print(const CFGFunction *fn);
- void print(const MLFunction *fn);
+ void printExt(const Function *fn);
+ void printCFG(const Function *fn);
+ void printML(const Function *fn);
void printAffineMap(AffineMap map);
void printAffineExpr(AffineExpr expr);
void ModulePrinter::print(const Function *fn) {
switch (fn->getKind()) {
case Function::Kind::ExtFunc:
- return print(cast<ExtFunction>(fn));
+ return printExt(fn);
case Function::Kind::CFGFunc:
- return print(cast<CFGFunction>(fn));
+ return printCFG(fn);
case Function::Kind::MLFunc:
- return print(cast<MLFunction>(fn));
+ return printML(fn);
}
}
os << '}';
}
-void ModulePrinter::print(const ExtFunction *fn) {
+void ModulePrinter::printExt(const Function *fn) {
os << "extfunc ";
printFunctionSignature(fn);
printFunctionAttributes(fn);
if (specialNameBuffer.empty()) {
switch (value->getKind()) {
- case SSAValueKind::BBArgument:
- // If this is an argument to the function, give it an 'arg' name.
- if (auto *bb = cast<BBArgument>(value)->getOwner())
- if (auto *fn = bb->getFunction())
- if (&fn->front() == bb) {
- specialName << "arg" << nextArgumentID++;
- break;
- }
- // Otherwise number it normally.
- valueIDs[value] = nextValueID++;
- return;
case SSAValueKind::BlockArgument:
// If this is an argument to the function, give it an 'arg' name.
if (auto *block = cast<BlockArgument>(value)->getOwner())
// Otherwise number it normally.
valueIDs[value] = nextValueID++;
return;
- case SSAValueKind::InstResult:
case SSAValueKind::StmtResult:
// This is an uninteresting result, give it a boring number and be
// done with it.
for (auto &op : *block) {
// We number instruction that have results, and we only number the first
// result.
- if (op.getNumResults() != 0)
- numberValueID(op.getResult(0));
+ if (auto *opInst = dyn_cast<OperationInst>(&op))
+ if (opInst->getNumResults() != 0)
+ numberValueID(opInst->getResult(0));
}
// Terminators do not define values.
}
os << '\n';
- for (auto &inst : block->getOperations()) {
+ for (auto &inst : block->getStatements()) {
os << " ";
print(&inst);
os << '\n';
os << "<<null instruction>>\n";
return;
}
- printOperation(inst);
+ auto *opInst = dyn_cast<OperationInst>(inst);
+ assert(opInst && "IfStmt/ForStmt aren't supported in CFG functions yet");
+ printOperation(opInst);
}
// Print the operands from "container" to "os", followed by a colon and their
printBranchOperands(term->getSuccessorOperands(index));
}
-void ModulePrinter::print(const CFGFunction *fn) {
+void ModulePrinter::printCFG(const Function *fn) {
CFGFunctionPrinter(fn, *this).print();
}
}
}
-void ModulePrinter::print(const MLFunction *fn) {
+void ModulePrinter::printML(const Function *fn) {
MLFunctionPrinter(fn, *this).print();
}
void SSAValue::print(raw_ostream &os) const {
switch (getKind()) {
- case SSAValueKind::BBArgument:
case SSAValueKind::BlockArgument:
// TODO: Improve this.
os << "<block argument>\n";
return;
- case SSAValueKind::InstResult:
- return getDefiningInst()->print(os);
case SSAValueKind::StmtResult:
return getDefiningStmt()->print(os);
case SSAValueKind::ForStmt:
void SSAValue::dump() const { print(llvm::errs()); }
void Instruction::print(raw_ostream &os) const {
- if (!getFunction()) {
+ auto *function = getFunction();
+ if (!function) {
os << "<<UNLINKED INSTRUCTION>>\n";
return;
}
- ModuleState state(getFunction()->getContext());
- ModulePrinter modulePrinter(os, state);
- CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
+ if (function->isCFG()) {
+ ModuleState state(function->getContext());
+ ModulePrinter modulePrinter(os, state);
+ CFGFunctionPrinter(function, modulePrinter).print(this);
+ } else {
+ ModuleState state(function->getContext());
+ ModulePrinter modulePrinter(os, state);
+ MLFunctionPrinter(function, modulePrinter).print(this);
+ }
}
void Instruction::dump() const {
}
void BasicBlock::print(raw_ostream &os) const {
- if (!getFunction()) {
+ auto *function = getFunction();
+ if (!function) {
os << "<<UNLINKED BLOCK>>\n";
return;
}
- ModuleState state(getFunction()->getContext());
- ModulePrinter modulePrinter(os, state);
- CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
+
+ if (function->isCFG()) {
+ ModuleState state(function->getContext());
+ ModulePrinter modulePrinter(os, state);
+ CFGFunctionPrinter(function, modulePrinter).print(this);
+ } else {
+ ModuleState state(function->getContext());
+ ModulePrinter modulePrinter(os, state);
+ MLFunctionPrinter(function, modulePrinter).print(this);
+ }
}
void BasicBlock::dump() const { print(llvm::errs()); }
/// Print out the name of the basic block without printing its body.
-void BasicBlock::printAsOperand(raw_ostream &os, bool printType) {
+void StmtBlock::printAsOperand(raw_ostream &os, bool printType) {
if (!getFunction()) {
os << "<<UNLINKED BLOCK>>\n";
return;
CFGFunctionPrinter(getFunction(), modulePrinter).printBBName(this);
}
-void Statement::print(raw_ostream &os) const {
- MLFunction *function = getFunction();
- if (!function) {
- os << "<<UNLINKED STATEMENT>>\n";
- return;
- }
-
- ModuleState state(function->getContext());
- ModulePrinter modulePrinter(os, state);
- MLFunctionPrinter(function, modulePrinter).print(this);
-}
-
-void Statement::dump() const { print(llvm::errs()); }
-
-void StmtBlock::printBlock(raw_ostream &os) const {
- const MLFunction *function = getFunction();
- ModuleState state(function->getContext());
- ModulePrinter modulePrinter(os, state);
- MLFunctionPrinter(function, modulePrinter).print(this);
-}
-
-void StmtBlock::dumpBlock() const { printBlock(llvm::errs()); }
-
void Function::print(raw_ostream &os) const {
ModuleState state(getContext());
ModulePrinter(os, state).print(this);
+++ /dev/null
-//===- BasicBlock.cpp - MLIR BasicBlock Class -----------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/BasicBlock.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/CFGFunction.h"
-using namespace mlir;
-
-BasicBlock::BasicBlock() {}
-
-BasicBlock::~BasicBlock() {
- for (BBArgument *arg : arguments)
- delete arg;
- arguments.clear();
-}
-
-//===----------------------------------------------------------------------===//
-// Argument list management.
-//===----------------------------------------------------------------------===//
-
-BBArgument *BasicBlock::addArgument(Type type) {
- auto *arg = new BBArgument(type, this);
- arguments.push_back(arg);
- return arg;
-}
-
-/// Add one argument to the argument list for each type specified in the list.
-auto BasicBlock::addArguments(ArrayRef<Type> types)
- -> llvm::iterator_range<args_iterator> {
- arguments.reserve(arguments.size() + types.size());
- auto initialSize = arguments.size();
- for (auto type : types) {
- addArgument(type);
- }
- return {arguments.data() + initialSize, arguments.data() + arguments.size()};
-}
-
-void BasicBlock::eraseArgument(unsigned index) {
- assert(index < arguments.size());
-
- // Delete the argument.
- delete arguments[index];
- arguments.erase(arguments.begin() + index);
-
- // Erase this argument from each of the predecessor's terminator.
- for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE;
- ++predIt) {
- auto *predTerminator = (*predIt)->getTerminator();
- predTerminator->eraseSuccessorOperand(predIt.getSuccessorIndex(), index);
- }
-}
-
-//===----------------------------------------------------------------------===//
-// Terminator management
-//===----------------------------------------------------------------------===//
-
-Instruction *BasicBlock::getTerminator() const {
- if (empty())
- return nullptr;
-
- // Check if the last instruction is a terminator.
- auto &backInst = operations.back();
- return backInst.isTerminator() ? const_cast<Instruction *>(&backInst)
- : nullptr;
-}
-
-/// Return true if this block has no predecessors.
-bool BasicBlock::hasNoPredecessors() const {
- return pred_begin() == pred_end();
-}
-
-/// If this basic block has exactly one predecessor, return it. Otherwise,
-/// return null.
-///
-/// Note that multiple edges from a single block (e.g. if you have a cond
-/// branch with the same block as the true/false destinations) is not
-/// considered to be a single predecessor.
-BasicBlock *BasicBlock::getSinglePredecessor() {
- auto it = pred_begin();
- if (it == pred_end())
- return nullptr;
- auto *firstPred = *it;
- ++it;
- return it == pred_end() ? firstPred : nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// ilist_traits for BasicBlock
-//===----------------------------------------------------------------------===//
-
-mlir::CFGFunction *
-llvm::ilist_traits<::mlir::BasicBlock>::getContainingFunction() {
- size_t Offset(
- size_t(&((CFGFunction *)nullptr->*CFGFunction::getSublistAccess(nullptr))));
- iplist<BasicBlock> *Anchor(static_cast<iplist<BasicBlock> *>(this));
- return reinterpret_cast<CFGFunction *>(reinterpret_cast<char *>(Anchor) -
- Offset);
-}
-
-/// This is a trait method invoked when a basic block is added to a function.
-/// We keep the function pointer up to date.
-void llvm::ilist_traits<::mlir::BasicBlock>::
-addNodeToList(BasicBlock *block) {
- assert(!block->function && "already in a function!");
- block->function = getContainingFunction();
-}
-
-/// This is a trait method invoked when an instruction is removed from a
-/// function. We keep the function pointer up to date.
-void llvm::ilist_traits<::mlir::BasicBlock>::
-removeNodeFromList(BasicBlock *block) {
- assert(block->function && "not already in a function!");
- block->function = nullptr;
-}
-
-/// This is a trait method invoked when an instruction is moved from one block
-/// to another. We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::BasicBlock>::
-transferNodesFromList(ilist_traits<BasicBlock> &otherList,
- block_iterator first, block_iterator last) {
- // If we are transferring instructions within the same function, the parent
- // pointer doesn't need to be updated.
- CFGFunction *curParent = getContainingFunction();
- if (curParent == otherList.getContainingFunction())
- return;
-
- // Update the 'function' member of each BasicBlock.
- for (; first != last; ++first)
- first->function = curParent;
-}
-
-//===----------------------------------------------------------------------===//
-// Manipulators
-//===----------------------------------------------------------------------===//
-
-/// Unlink this BasicBlock from its CFGFunction and delete it.
-void BasicBlock::eraseFromFunction() {
- assert(getFunction() && "BasicBlock has no parent");
- getFunction()->getBlocks().erase(this);
-}
-
-/// Split the basic block into two basic blocks before the specified
-/// instruction or iterator.
-///
-/// Note that all instructions BEFORE the specified iterator stay as part of
-/// the original basic block, an unconditional branch is added to the original
-/// block (going to the new block), and the rest of the instructions in the
-/// original block are moved to the new BB, including the old terminator. The
-/// newly formed BasicBlock is returned.
-///
-/// This function invalidates the specified iterator.
-BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) {
- // Start by creating a new basic block, and insert it immediate after this
- // one in the containing function.
- auto newBB = new BasicBlock();
- getFunction()->getBlocks().insert(++CFGFunction::iterator(this), newBB);
- auto branchLoc =
- splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc();
-
- // Move all of the operations from the split point to the end of the function
- // into the new block.
- newBB->getOperations().splice(newBB->end(), getOperations(), splitBefore,
- end());
-
- // Create an unconditional branch to the new block, and move our terminator
- // to the new block.
- CFGFuncBuilder(this).create<BranchOp>(branchLoc, newBB);
- return newBB;
-}
}
/// Create an operation given the fields represented as an OperationState.
-Instruction *CFGFuncBuilder::createOperation(const OperationState &state) {
+OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
SmallVector<CFGValue *, 8> operands;
operands.reserve(state.operands.size());
// Allow null operands as they act as sentinal barriers between successor
// operand lists.
for (auto elt : state.operands)
- operands.push_back(elt ? cast<CFGValue>(elt) : nullptr);
+ operands.push_back(cast_or_null<CFGValue>(elt));
auto *op =
- Instruction::create(state.location, state.name, operands, state.types,
- state.attributes, state.successors, context);
- block->getOperations().insert(insertPoint, op);
+ OperationInst::create(state.location, state.name, operands, state.types,
+ state.attributes, state.successors, context);
+ block->getStatements().insert(insertPoint, op);
return op;
}
auto *op =
OperationStmt::create(state.location, state.name, operands, state.types,
- state.attributes, state.successorsS, context);
+ state.attributes, state.successors, context);
block->getStatements().insert(insertPoint, op);
return op;
}
bool BranchOp::verify() const {
// ML functions do not have branching terminators.
- if (!isa<Instruction>(getOperation()))
+ if (getOperation()->getOperationFunction()->isML())
return (emitOpError("cannot occur in a ML function"), true);
return false;
}
bool CondBranchOp::verify() const {
// ML functions do not have branching terminators.
- if (!isa<Instruction>(getOperation()))
+ if (getOperation()->getOperationFunction()->isML())
return (emitOpError("cannot occur in a ML function"), true);
if (!getCondition()->getType().isInteger(1))
return emitOpError("expected condition type was boolean (i1)");
}
bool ReturnOp::verify() const {
- const Function *function;
- if (auto *stmt = dyn_cast<OperationStmt>(getOperation()))
- function = stmt->getFunction();
- else
- function = cast<Instruction>(getOperation())->getFunction();
+ auto *function = cast<OperationStmt>(getOperation())->getFunction();
// The operand number and types must match the function signature.
const auto &results = function->getType().getResults();
// limitations under the License.
// =============================================================================
+#include "mlir/IR/Function.h"
#include "AttributeListStorage.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
Function::Function(Kind kind, Location location, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs)
: nameAndKind(Identifier::get(name, type.getContext()), kind),
- location(location), type(type) {
+ location(location), type(type), blocks(this) {
this->attrs = AttributeListStorage::get(attrs, getContext());
+
+ // Creating of an MLFunction automatically populates the entry block and
+ // arguments.
+ // TODO(clattner): Unify this behavior.
+ if (kind == Kind::MLFunc) {
+ // The body of an MLFunction always has one block.
+ auto *entry = new StmtBlock();
+ blocks.push_back(entry);
+
+ // Initialize the arguments.
+ entry->addArguments(type.getInputs());
+ }
}
Function::~Function() {
+ // Instructions may have cyclic references, which need to be dropped before we
+ // can start deleting them.
+ for (auto &bb : *this)
+ for (auto &inst : bb)
+ inst.dropAllReferences();
+
// Clean up function attributes referring to this function.
FunctionAttr::dropFunctionReference(this);
}
MLIRContext *Function::getContext() const { return getType().getContext(); }
-/// Delete this object.
-void Function::destroy() {
- switch (getKind()) {
- case Kind::ExtFunc:
- delete cast<ExtFunction>(this);
- break;
- case Kind::MLFunc:
- delete cast<MLFunction>(this);
- break;
- case Kind::CFGFunc:
- delete cast<CFGFunction>(this);
- break;
- }
-}
-
Module *llvm::ilist_traits<Function>::getContainingModule() {
size_t Offset(
size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
}
/// Unlink this function from its Module and delete it.
-void Function::eraseFromModule() {
+void Function::erase() {
assert(getModule() && "Function has no parent");
getModule()->getFunctions().erase(this);
}
bool Function::emitError(const Twine &message) const {
return getContext()->emitError(getLoc(), message);
}
-//===----------------------------------------------------------------------===//
-// ExtFunction implementation.
-//===----------------------------------------------------------------------===//
-
-ExtFunction::ExtFunction(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs)
- : Function(Kind::ExtFunc, location, name, type, attrs) {}
-
-//===----------------------------------------------------------------------===//
-// CFGFunction implementation.
-//===----------------------------------------------------------------------===//
-
-CFGFunction::CFGFunction(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs)
- : Function(Kind::CFGFunc, location, name, type, attrs) {}
-
-CFGFunction::~CFGFunction() {
- // Instructions may have cyclic references, which need to be dropped before we
- // can start deleting them.
- for (auto &bb : *this)
- for (auto &inst : bb)
- inst.dropAllReferences();
-}
//===----------------------------------------------------------------------===//
// MLFunction implementation.
//===----------------------------------------------------------------------===//
-MLFunction::MLFunction(Location location, StringRef name, FunctionType type,
- ArrayRef<NamedAttribute> attrs)
- : Function(Kind::MLFunc, location, name, type, attrs), body(this) {
-
- // The body of an MLFunction always has one block.
- auto *entry = new StmtBlock();
- body.push_back(entry);
-
- // Initialize the arguments.
- entry->addArguments(type.getInputs());
-}
-
-MLFunction::~MLFunction() {
- // Explicitly erase statements instead of relying of 'StmtBlock' destructor
- // since child statements need to be destroyed before function arguments
- // are destroyed.
- getBody()->clear();
-}
-
const OperationStmt *MLFunction::getReturnStmt() const {
return cast<OperationStmt>(&getBody()->back());
}
+++ /dev/null
-//===- Instructions.cpp - MLIR CFGFunction Instruction Classes ------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/Instructions.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLIRContext.h"
-using namespace mlir;
-
-/// Replace all uses of 'this' value with the new value, updating anything in
-/// the IR that uses 'this' to use the other value instead. When this returns
-/// there are zero uses of 'this'.
-void IRObjectWithUseList::replaceAllUsesWith(IRObjectWithUseList *newValue) {
- assert(this != newValue && "cannot RAUW a value with itself");
- while (!use_empty()) {
- use_begin()->set(newValue);
- }
-}
-
-/// Return the result number of this result.
-unsigned InstResult::getResultNumber() const {
- // Results are always stored consecutively, so use pointer subtraction to
- // figure out what number this is.
- return this - &getOwner()->getInstResults()[0];
-}
-
-/// Return which operand this is in the operand list.
-template <> unsigned InstOperand::getOperandNumber() const {
- return this - &getOwner()->getInstOperands()[0];
-}
-
-/// Return which operand this is in the operand list.
-template <> unsigned BasicBlockOperand::getOperandNumber() const {
- return this - &getOwner()->getBasicBlockOperands()[0];
-}
-
-//===----------------------------------------------------------------------===//
-// Instruction
-//===----------------------------------------------------------------------===//
-
-void Instruction::setSuccessor(BasicBlock *block, unsigned index) {
- assert(index < getNumSuccessors());
- getBasicBlockOperands()[index].set(block);
-}
-
-/// Create a new Instruction with the specified fields.
-Instruction *Instruction::create(Location location, OperationName name,
- ArrayRef<CFGValue *> operands,
- ArrayRef<Type> resultTypes,
- ArrayRef<NamedAttribute> attributes,
- ArrayRef<BasicBlock *> successors,
- MLIRContext *context) {
- unsigned numSuccessors = successors.size();
- // Input operands are nullptr-separated for each successors in the case of
- // terminators, the nullptr aren't actually stored.
- unsigned numOperands = operands.size() - llvm::count(operands, nullptr);
-
- auto byteSize =
- totalSizeToAlloc<InstResult, InstOperand, BasicBlockOperand, unsigned>(
- resultTypes.size(), numOperands, numSuccessors, numSuccessors);
- void *rawMem = malloc(byteSize);
-
- // Initialize the Instruction part of the instruction.
- auto inst = ::new (rawMem)
- Instruction(location, name, resultTypes.size(), numOperands,
- numSuccessors, attributes, context);
-
- // Initialize the results and operands.
- auto instResults = inst->getInstResults();
- for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
- new (&instResults[i]) InstResult(resultTypes[i], inst);
-
- // instOperandIt is the iterator in the tail-allocated memory for the
- // operands, and operandIt is the iterator in the input operands array.
- // operandIt skips nullptr in the input that acts as sentinels marking the
- // separation between multiple basic block successors for terminators.
- auto instOperandIt = inst->getInstOperands().begin();
- unsigned operandIt = 0, operandE = operands.size();
- for (; operandIt != operandE; ++operandIt) {
- if (!operands[operandIt])
- break;
- new (instOperandIt++) InstOperand(inst, operands[operandIt]);
- }
-
- // Check to see if a sentinel (nullptr) operand was encountered.
- unsigned currentSuccNum = 0;
- if (operandIt != operandE) {
- assert(inst->isTerminator() &&
- "Sentinel operand found in non terminator operand list.");
- auto instBlockOperands = inst->getBasicBlockOperands();
- unsigned *succOperandCountIt = inst->getTrailingObjects<unsigned>();
- unsigned *succOperandCountE = succOperandCountIt + numSuccessors;
- (void)succOperandCountE;
-
- for (; operandIt != operandE; ++operandIt) {
- // If we encounter a sentinal branch to the next operand update the count
- // variable.
- if (!operands[operandIt]) {
- assert(currentSuccNum < numSuccessors);
-
- // After the first iteration update the successor operand count
- // variable.
- if (currentSuccNum != 0) {
- ++succOperandCountIt;
- assert(succOperandCountIt != succOperandCountE &&
- "More sentinal operands than successors.");
- }
-
- new (&instBlockOperands[currentSuccNum])
- BasicBlockOperand(inst, successors[currentSuccNum]);
- *succOperandCountIt = 0;
- ++currentSuccNum;
- continue;
- }
- new (instOperandIt++) InstOperand(inst, operands[operandIt]);
- ++(*succOperandCountIt);
- }
- }
-
- // Verify that the amount of sentinal operands is equivalent to the number of
- // successors.
- assert(currentSuccNum == numSuccessors);
- return inst;
-}
-
-Instruction *Instruction::clone() const {
- SmallVector<CFGValue *, 8> operands;
- SmallVector<Type, 8> resultTypes;
- SmallVector<BasicBlock *, 1> successors;
-
- // Put together the results.
- for (auto *result : getResults())
- resultTypes.push_back(result->getType());
-
- // If the instruction is a terminator the successor and non-successor operand
- // lists are interleaved with sentinal(nullptr) operands.
- if (isTerminator()) {
- // To interleave the operand lists we iterate in reverse and insert the
- // operands in-place.
- operands.resize(getNumOperands() + getNumSuccessors());
- successors.resize(getNumSuccessors());
- int cloneOperandIt = operands.size() - 1, operandIt = getNumOperands() - 1;
- for (int succIt = getNumSuccessors() - 1, succE = 0; succIt >= succE;
- --succIt) {
- successors[succIt] = const_cast<BasicBlock *>(getSuccessor(succIt));
-
- // Add the successor operands in-place in reverse order.
- for (unsigned i = 0, e = getNumSuccessorOperands(succIt); i != e;
- ++i, --cloneOperandIt, --operandIt) {
- operands[cloneOperandIt] =
- const_cast<CFGValue *>(getOperand(operandIt));
- }
-
- // Add a null operand for the barrier.
- operands[cloneOperandIt--] = nullptr;
- }
-
- // Add the rest of the non-successor operands.
- for (; cloneOperandIt >= 0; --cloneOperandIt, --operandIt)
- operands[cloneOperandIt] = const_cast<CFGValue *>(getOperand(operandIt));
- // For non terminators we can simply add each of the instructions in place.
- } else {
- for (auto *operand : getOperands())
- operands.push_back(const_cast<CFGValue *>(operand));
- }
-
- return create(getLoc(), getName(), operands, resultTypes, getAttrs(),
- successors, getContext());
-}
-
-Instruction::Instruction(Location location, OperationName name,
- unsigned numResults, unsigned numOperands,
- unsigned numSuccessors,
- ArrayRef<NamedAttribute> attributes,
- MLIRContext *context)
- : Operation(/*isInstruction=*/true, name, attributes, context),
- IROperandOwner(IROperandOwner::Kind::Instruction, location),
- numResults(numResults), numSuccs(numSuccessors),
- numOperands(numOperands) {}
-
-Instruction::~Instruction() {
- assert(block == nullptr && "instruction destroyed but still in a block");
-
- // Explicitly run the destructors for the results
- for (auto &result : getInstResults())
- result.~InstResult();
-
- // Explicitly run the destructors for the operands.
- for (auto &result : getInstOperands())
- result.~InstOperand();
-
- // Explicitly run the destructors for the successors.
- if (isTerminator())
- for (auto &successor : getBasicBlockOperands())
- successor.~BasicBlockOperand();
-}
-
-void Instruction::eraseOperand(unsigned index) {
- assert(index < getNumOperands());
- auto Operands = getInstOperands();
- // Shift all operands down by 1.
- std::rotate(&Operands[index], &Operands[index + 1],
- &Operands[numOperands - 1]);
- --numOperands;
- Operands[getNumOperands()].~InstOperand();
-}
-
-/// Destroy this instruction.
-void Instruction::destroy() {
- this->~Instruction();
- free(this);
-}
-
-CFGFunction *Instruction::getFunction() {
- auto *block = getBlock();
- return block ? block->getFunction() : nullptr;
-}
-
-/// This drops all operand uses from this instruction, which is an essential
-/// step in breaking cyclic dependences between references when they are to
-/// be deleted.
-void Instruction::dropAllReferences() {
- for (auto &op : getInstOperands())
- op.drop();
-
- if (isTerminator())
- for (auto &dest : getBasicBlockOperands())
- dest.drop();
-}
-
-/// Emit a note about this instruction, reporting up to any diagnostic
-/// handlers that may be listening.
-void Instruction::emitNote(const Twine &message) const {
- getContext()->emitDiagnostic(getLoc(), message,
- MLIRContext::DiagnosticKind::Note);
-}
-
-/// Emit a warning about this operation, reporting up to any diagnostic
-/// handlers that may be listening.
-void Instruction::emitWarning(const Twine &message) const {
- getContext()->emitDiagnostic(getLoc(), message,
- MLIRContext::DiagnosticKind::Warning);
-}
-
-/// Emit an error about fatal conditions with this operation, reporting up to
-/// any diagnostic handlers that may be listening. This function always
-/// returns true. NOTE: This may terminate the containing application, only use
-/// when the IR is in an inconsistent state.
-bool Instruction::emitError(const Twine &message) const {
- return getContext()->emitError(getLoc(), message);
-}
-
-void llvm::ilist_traits<::mlir::Instruction>::deleteNode(Instruction *inst) {
- inst->destroy();
-}
-
-mlir::BasicBlock *
-llvm::ilist_traits<::mlir::Instruction>::getContainingBlock() {
- size_t Offset(
- size_t(&((BasicBlock *)nullptr->*BasicBlock::getSublistAccess(nullptr))));
- iplist<Instruction> *Anchor(static_cast<iplist<Instruction> *>(this));
- return reinterpret_cast<BasicBlock *>(reinterpret_cast<char *>(Anchor) -
- Offset);
-}
-
-/// This is a trait method invoked when an instruction is added to a block. We
-/// keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Instruction>::addNodeToList(Instruction *inst) {
- assert(!inst->getBlock() && "already in a basic block!");
- inst->block = getContainingBlock();
-}
-
-/// This is a trait method invoked when an instruction is removed from a block.
-/// We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Instruction>::removeNodeFromList(
- Instruction *inst) {
- assert(inst->block && "not already in a basic block!");
- inst->block = nullptr;
-}
-
-/// This is a trait method invoked when an instruction is moved from one block
-/// to another. We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Instruction>::transferNodesFromList(
- ilist_traits<Instruction> &otherList, instr_iterator first,
- instr_iterator last) {
- // If we are transferring instructions within the same basic block, the block
- // pointer doesn't need to be updated.
- BasicBlock *curParent = getContainingBlock();
- if (curParent == otherList.getContainingBlock())
- return;
-
- // Update the 'block' member of each instruction.
- for (; first != last; ++first)
- first->block = curParent;
-}
-
-/// Unlink this instruction from its BasicBlock and delete it.
-void Instruction::erase() {
- assert(getBlock() && "Instruction has no parent");
- getBlock()->getOperations().erase(this);
-}
-
-/// Unlink this operation instruction from its current basic block and insert
-/// it right before `existingInst` which may be in the same or another block
-/// in the same function.
-void Instruction::moveBefore(Instruction *existingInst) {
- assert(existingInst && "Cannot move before a null instruction");
- return moveBefore(existingInst->getBlock(), existingInst->getIterator());
-}
-
-/// Unlink this operation instruction from its current basic block and insert
-/// it right before `iterator` in the specified basic block.
-void Instruction::moveBefore(BasicBlock *block,
- llvm::iplist<Instruction>::iterator iterator) {
- block->getOperations().splice(iterator, getBlock()->getOperations(),
- getIterator());
-}
#include "mlir/IR/Operation.h"
#include "AttributeListStorage.h"
-#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/Dialect.h"
-#include "mlir/IR/Instructions.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
/// Return the context this operation is associated with.
MLIRContext *Operation::getContext() const {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->getContext();
return llvm::cast<OperationStmt>(this)->getContext();
}
/// The source location the operation was defined or derived from. Note that
/// it is possible for this pointer to be null.
Location Operation::getLoc() const {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->getLoc();
return llvm::cast<OperationStmt>(this)->getLoc();
}
/// Set the source location the operation was defined or derived from.
void Operation::setLoc(Location loc) {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- inst->setLoc(loc);
- else
- llvm::cast<OperationStmt>(this)->setLoc(loc);
+ llvm::cast<OperationStmt>(this)->setLoc(loc);
}
/// Return the function this operation is defined in.
Function *Operation::getOperationFunction() {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->getFunction();
return llvm::cast<OperationStmt>(this)->getFunction();
}
/// Return the number of operands this operation has.
unsigned Operation::getNumOperands() const {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->getNumOperands();
-
return llvm::cast<OperationStmt>(this)->getNumOperands();
}
SSAValue *Operation::getOperand(unsigned idx) {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->getOperand(idx);
-
return llvm::cast<OperationStmt>(this)->getOperand(idx);
}
void Operation::setOperand(unsigned idx, SSAValue *value) {
- if (auto *inst = llvm::dyn_cast<Instruction>(this)) {
- inst->setOperand(idx, llvm::cast<CFGValue>(value));
- } else {
- auto *stmt = llvm::cast<OperationStmt>(this);
- stmt->setOperand(idx, llvm::cast<MLValue>(value));
- }
+ auto *stmt = llvm::cast<OperationStmt>(this);
+ stmt->setOperand(idx, llvm::cast<MLValue>(value));
}
/// Return the number of results this operation has.
unsigned Operation::getNumResults() const {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->getNumResults();
-
return llvm::cast<OperationStmt>(this)->getNumResults();
}
/// Return the indicated result.
SSAValue *Operation::getResult(unsigned idx) {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->getResult(idx);
-
return llvm::cast<OperationStmt>(this)->getResult(idx);
}
unsigned Operation::getNumSuccessors() const {
assert(isTerminator() && "Only terminators have successors.");
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->getNumSuccessors();
-
return llvm::cast<OperationStmt>(this)->getNumSuccessors();
}
unsigned Operation::getNumSuccessorOperands(unsigned index) const {
assert(isTerminator() && "Only terminators have successors.");
-
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->getNumSuccessorOperands(index);
-
return llvm::cast<OperationStmt>(this)->getNumSuccessorOperands(index);
}
BasicBlock *Operation::getSuccessor(unsigned index) {
- assert(isTerminator() && "Only terminators have successors.");
- assert(llvm::isa<Instruction>(this) &&
- "Only instructions have basic block successors.");
- return llvm::cast<Instruction>(this)->getSuccessor(index);
+ assert(isTerminator() && "Only terminators have successors");
+ return llvm::cast<OperationStmt>(this)->getSuccessor(index);
}
void Operation::setSuccessor(BasicBlock *block, unsigned index) {
- assert(isTerminator() && "Only terminators have successors.");
- assert(llvm::isa<Instruction>(this) &&
- "Only instructions have basic block successors.");
- llvm::cast<Instruction>(this)->setSuccessor(block, index);
+ assert(isTerminator() && "Only terminators have successors");
+ llvm::cast<OperationStmt>(this)->setSuccessor(block, index);
}
void Operation::eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
- assert(isTerminator() && "Only terminators have successors.");
- assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
- return llvm::cast<Instruction>(this)->eraseSuccessorOperand(succIndex,
- opIndex);
+ assert(isTerminator() && "Only terminators have successors");
+ return llvm::cast<OperationStmt>(this)->eraseSuccessorOperand(succIndex,
+ opIndex);
}
auto Operation::getSuccessorOperands(unsigned index) const
-> llvm::iterator_range<const_operand_iterator> {
assert(isTerminator() && "Only terminators have successors.");
- assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
unsigned succOperandIndex =
- llvm::cast<Instruction>(this)->getSuccessorOperandIndex(index);
+ llvm::cast<OperationStmt>(this)->getSuccessorOperandIndex(index);
return {const_operand_iterator(this, succOperandIndex),
const_operand_iterator(this, succOperandIndex +
getNumSuccessorOperands(index))};
auto Operation::getSuccessorOperands(unsigned index)
-> llvm::iterator_range<operand_iterator> {
assert(isTerminator() && "Only terminators have successors.");
- assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
unsigned succOperandIndex =
- llvm::cast<Instruction>(this)->getSuccessorOperandIndex(index);
+ llvm::cast<OperationStmt>(this)->getSuccessorOperandIndex(index);
return {operand_iterator(this, succOperandIndex),
operand_iterator(this,
succOperandIndex + getNumSuccessorOperands(index))};
}
void Operation::moveBefore(Operation *existingOp) {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->moveBefore(llvm::cast<Instruction>(existingOp));
return llvm::cast<OperationStmt>(this)->moveBefore(
llvm::cast<OperationStmt>(existingOp));
}
/// Remove this operation from its parent block and delete it.
void Operation::erase() {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->erase();
return llvm::cast<OperationStmt>(this)->erase();
}
}
void Operation::print(raw_ostream &os) const {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->print(os);
return llvm::cast<OperationStmt>(this)->print(os);
}
void Operation::dump() const {
- if (auto *inst = llvm::dyn_cast<Instruction>(this))
- return inst->dump();
return llvm::cast<OperationStmt>(this)->dump();
}
llvm::cast_convert_val<mlir::Operation, mlir::IROperandOwner *,
mlir::IROperandOwner *>::doit(const mlir::IROperandOwner
*value) {
+ // TODO(clattner): obsolete this.
const Operation *op;
- if (auto *ptr = dyn_cast<OperationStmt>(value))
- op = ptr;
- else
- op = cast<Instruction>(value);
+ auto *ptr = cast<OperationStmt>(value);
+ op = ptr;
return const_cast<Operation *>(op);
}
bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
// Verify that the operation is at the end of the respective parent block.
- if (auto *stmt = dyn_cast<OperationStmt>(op)) {
+ if (op->getOperationFunction()->isML()) {
+ auto *stmt = cast<OperationStmt>(op);
StmtBlock *block = stmt->getBlock();
if (!block || block->getContainingStmt() || &block->back() != stmt)
return op->emitOpError("must be the last statement in the ML function");
} else {
- const Instruction *inst = cast<Instruction>(op);
+ auto *inst = cast<OperationInst>(op);
const BasicBlock *block = inst->getBlock();
if (!block || &block->back() != inst)
return op->emitOpError(
// =============================================================================
#include "mlir/IR/SSAValue.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/Instructions.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h"
using namespace mlir;
/// If this value is the result of an Instruction, return the instruction
/// that defines it.
-Instruction *SSAValue::getDefiningInst() {
+OperationInst *SSAValue::getDefiningInst() {
if (auto *result = dyn_cast<InstResult>(this))
return result->getOwner();
return nullptr;
/// Return the function that this SSAValue is defined in.
Function *SSAValue::getFunction() {
switch (getKind()) {
- case SSAValueKind::BBArgument:
- return cast<BBArgument>(this)->getFunction();
- case SSAValueKind::InstResult:
- return getDefiningInst()->getFunction();
case SSAValueKind::BlockArgument:
return cast<BlockArgument>(this)->getFunction();
case SSAValueKind::StmtResult:
// IROperandOwner implementation.
//===----------------------------------------------------------------------===//
+/// Replace all uses of 'this' value with the new value, updating anything in
+/// the IR that uses 'this' to use the other value instead. When this returns
+/// there are zero uses of 'this'.
+void IRObjectWithUseList::replaceAllUsesWith(IRObjectWithUseList *newValue) {
+ assert(this != newValue && "cannot RAUW a value with itself");
+ while (!use_empty()) {
+ use_begin()->set(newValue);
+ }
+}
+
/// Return the context this operation is associated with.
MLIRContext *IROperandOwner::getContext() const {
switch (getKind()) {
}
}
-//===----------------------------------------------------------------------===//
-// CFGValue implementation.
-//===----------------------------------------------------------------------===//
-
-/// Return the function that this CFGValue is defined in.
-CFGFunction *CFGValue::getFunction() {
- return cast<CFGFunction>(static_cast<SSAValue *>(this)->getFunction());
-}
-
-//===----------------------------------------------------------------------===//
-// BBArgument implementation.
-//===----------------------------------------------------------------------===//
-
-/// Return the function that this argument is defined in.
-CFGFunction *BBArgument::getFunction() {
- if (auto *owner = getOwner())
- return owner->getFunction();
- return nullptr;
-}
-
//===----------------------------------------------------------------------===//
// MLValue implementation.
//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
return getContext()->emitError(getLoc(), message);
}
+// Returns whether the Statement is a terminator.
+bool Statement::isTerminator() const {
+ if (auto *op = dyn_cast<OperationStmt>(this))
+ return op->isTerminator();
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
//===----------------------------------------------------------------------===//
getIterator());
}
+/// This drops all operand uses from this instruction, which is an essential
+/// step in breaking cyclic dependences between references when they are to
+/// be deleted.
+void Statement::dropAllReferences() {
+ for (auto &op : getStmtOperands())
+ op.drop();
+
+ if (isTerminator())
+ for (auto &dest : cast<OperationInst>(this)->getBlockOperands())
+ dest.drop();
+}
+
//===----------------------------------------------------------------------===//
// OperationStmt
//===----------------------------------------------------------------------===//
ArrayRef<StmtBlock *> successors,
MLIRContext *context) {
unsigned numSuccessors = successors.size();
+
+ // Input operands are nullptr-separated for each successors in the case of
+ // terminators, the nullptr aren't actually stored.
+ unsigned numOperands = operands.size() - numSuccessors;
+
auto byteSize =
totalSizeToAlloc<StmtResult, StmtBlockOperand, unsigned, StmtOperand>(
- resultTypes.size(), numSuccessors, numSuccessors, operands.size());
+ resultTypes.size(), numSuccessors, numSuccessors, numOperands);
void *rawMem = malloc(byteSize);
// Initialize the OperationStmt part of the statement.
auto stmt = ::new (rawMem)
- OperationStmt(location, name, operands.size(), resultTypes.size(),
+ OperationStmt(location, name, numOperands, resultTypes.size(),
numSuccessors, attributes, context);
// Initialize the results and operands.
// Verify that the amount of sentinal operands is equivalent to the number
// of successors.
assert(currentSuccNum == numSuccessors);
-
return stmt;
}
for (auto &result : getStmtResults())
result.~StmtResult();
+
+ // Explicitly run the destructors for the successors.
+ if (isTerminator())
+ for (auto &successor : getBlockOperands())
+ successor.~StmtBlockOperand();
}
void OperationStmt::destroy() {
bool OperationStmt::isReturn() const { return isa<ReturnOp>(); }
+void OperationStmt::setSuccessor(BasicBlock *block, unsigned index) {
+ assert(index < getNumSuccessors());
+ getBlockOperands()[index].set(block);
+}
+
+void OperationInst::eraseOperand(unsigned index) {
+ assert(index < getNumOperands());
+ auto Operands = getStmtOperands();
+ // Shift all operands down by 1.
+ std::rotate(&Operands[index], &Operands[index + 1],
+ &Operands[numOperands - 1]);
+ --numOperands;
+ Operands[getNumOperands()].~StmtOperand();
+}
+
//===----------------------------------------------------------------------===//
// ForStmt
//===----------------------------------------------------------------------===//
return newIf;
}
+
+Statement *Statement::clone(MLIRContext *context) const {
+ DenseMap<const MLValue *, MLValue *> operandMap;
+ return clone(operandMap, context);
+}
// =============================================================================
#include "mlir/IR/StmtBlock.h"
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
using namespace mlir;
StmtBlock::~StmtBlock() {
return it == pred_end() ? firstPred : nullptr;
}
+//===----------------------------------------------------------------------===//
+// Other
+//===----------------------------------------------------------------------===//
+
+/// Unlink this BasicBlock from its CFGFunction and delete it.
+void BasicBlock::eraseFromFunction() {
+ assert(getFunction() && "BasicBlock has no parent");
+ getFunction()->getBlocks().erase(this);
+}
+
+/// Split the basic block into two basic blocks before the specified
+/// instruction or iterator.
+///
+/// Note that all instructions BEFORE the specified iterator stay as part of
+/// the original basic block, an unconditional branch is added to the original
+/// block (going to the new block), and the rest of the instructions in the
+/// original block are moved to the new BB, including the old terminator. The
+/// newly formed BasicBlock is returned.
+///
+/// This function invalidates the specified iterator.
+BasicBlock *BasicBlock::splitBasicBlock(iterator splitBefore) {
+ // Start by creating a new basic block, and insert it immediate after this
+ // one in the containing function.
+ auto newBB = new BasicBlock();
+ getFunction()->getBlocks().insert(++CFGFunction::iterator(this), newBB);
+ auto branchLoc =
+ splitBefore == end() ? getTerminator()->getLoc() : splitBefore->getLoc();
+
+ // Move all of the operations from the split point to the end of the function
+ // into the new block.
+ newBB->getStatements().splice(newBB->end(), getStatements(), splitBefore,
+ end());
+
+ // Create an unconditional branch to the new block, and move our terminator
+ // to the new block.
+ CFGFuncBuilder(this).create<BranchOp>(branchLoc, newBB);
+ return newBB;
+}
+
//===----------------------------------------------------------------------===//
// StmtBlockList
//===----------------------------------------------------------------------===//
-StmtBlockList::StmtBlockList(MLFunction *container) : container(container) {}
+StmtBlockList::StmtBlockList(Function *container) : container(container) {}
StmtBlockList::StmtBlockList(Statement *container) : container(container) {}
+CFGFunction *StmtBlockList::getFunction() {
+ return dyn_cast_or_null<CFGFunction>(getContainingFunction());
+}
+
Statement *StmtBlockList::getContainingStmt() {
return container.dyn_cast<Statement *>();
}
-MLFunction *StmtBlockList::getContainingFunction() {
- return container.dyn_cast<MLFunction *>();
+Function *StmtBlockList::getContainingFunction() {
+ return container.dyn_cast<Function *>();
}
StmtBlockList *llvm::ilist_traits<::mlir::StmtBlock>::getContainingBlockList() {
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
-#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
~ParserState() {
// Destroy the forward references upon error.
for (auto forwardRef : functionForwardRefs)
- forwardRef.second->destroy();
+ delete forwardRef.second;
functionForwardRefs.clear();
}
if (!function) {
auto &entry = state.functionForwardRefs[name];
if (!entry)
- entry = new ExtFunction(getEncodedSourceLocation(nameLoc), name, type,
- /*attrs=*/{});
+ entry = new Function(Function::Kind::ExtFunc,
+ getEncodedSourceLocation(nameLoc), name, type,
+ /*attrs=*/{});
function = entry;
}
// cannot be created through normal user input, allowing us to distinguish
// them.
auto name = OperationName("placeholder", getContext());
- auto *inst = Instruction::create(getEncodedSourceLocation(loc), name,
- /*operands=*/{}, type,
- /*attributes=*/{},
- /*successors=*/{}, getContext());
+ auto *inst = OperationInst::create(getEncodedSourceLocation(loc), name,
+ /*operands=*/{}, type,
+ /*attributes=*/{},
+ /*successors=*/{}, getContext());
forwardReferencePlaceholders[inst->getResult(0)] = loc;
return inst->getResult(0);
}
// Okay, the external function definition was parsed correctly.
auto *function =
- new ExtFunction(getEncodedSourceLocation(loc), name, type, attrs);
+ new Function(Function::Kind::ExtFunc, getEncodedSourceLocation(loc), name,
+ type, attrs);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
// Okay, the CFG function signature was parsed correctly, create the
// function.
auto *function =
- new CFGFunction(getEncodedSourceLocation(loc), name, type, attrs);
+ new Function(Function::Kind::CFGFunc, getEncodedSourceLocation(loc), name,
+ type, attrs);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
// Okay, the ML function signature was parsed correctly, create the
// function.
- auto *function =
- new MLFunction(getEncodedSourceLocation(loc), name, type, attrs);
+ auto *function = new Function(
+ Function::Kind::MLFunc, getEncodedSourceLocation(loc), name, type, attrs);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
// Now that all references to the forward definition placeholders are
// resolved, we can deallocate the placeholders.
for (auto forwardRef : getState().functionForwardRefs)
- forwardRef.second->destroy();
+ delete forwardRef.second;
getState().functionForwardRefs.clear();
return ParseSuccess;
}
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/Statements.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/FileUtilities.h"
bool convertBasicBlock(const BasicBlock &bb, bool ignoreArguments = false);
bool convertCFGFunction(const CFGFunction &cfgFunc, llvm::Function &llvmFunc);
bool convertFunctions(const Module &mlirModule, llvm::Module &llvmModule);
- bool convertInstruction(const Instruction &inst);
+ bool convertInstruction(const OperationInst &inst);
void connectPHINodes(const CFGFunction &cfgFunc);
// FIXME(zinenko): this should eventually become a separate MLIR pass that
// converts MLIR standard operations into LLVM IR dialect; the translation in
// that case would become a simple 1:1 instruction and value remapping.
-bool ModuleLowerer::convertInstruction(const Instruction &inst) {
+bool ModuleLowerer::convertInstruction(const OperationInst &inst) {
if (auto op = inst.dyn_cast<AddIOp>())
return valueMapping[op->getResult()] =
builder.CreateAdd(valueMapping[op->getOperand(0)],
// Traverse instructions.
for (const auto &inst : bb) {
- if (convertInstruction(inst))
+ auto *op = dyn_cast<OperationInst>(&inst);
+ if (!op)
+ return inst.emitError("unsupported operation");
+
+ if (convertInstruction(*op))
return true;
}
const BasicBlock *pred,
unsigned numArguments,
unsigned index) {
- const Instruction &terminator = *pred->getTerminator();
+ auto &terminator = *pred->getTerminator();
if (terminator.isa<BranchOp>()) {
return terminator.getOperand(index);
}
// call graph with cycles. We don't expect MLFunctions here.
for (const Function &function : mlirModule) {
const Function *functionPtr = &function;
- if (!isa<ExtFunction>(functionPtr) && !isa<CFGFunction>(functionPtr))
+ if (functionPtr->isML())
continue;
llvm::Constant *llvmFuncCst = llvmModule.getOrInsertFunction(
function.getName(), convertFunctionType(function.getType()));
// Convert CFG functions.
for (const Function &function : mlirModule) {
const Function *functionPtr = &function;
- auto cfgFunction = dyn_cast<CFGFunction>(functionPtr);
- if (!cfgFunction)
+ if (!functionPtr->isCFG())
continue;
- llvm::Function *llvmFunc = functionMapping[cfgFunction];
+ llvm::Function *llvmFunc = functionMapping[functionPtr];
// Add function arguments to the value remapping table. In CFGFunction,
// arguments of the first block are those of the function.
- assert(!cfgFunction->getBlocks().empty() &&
+ assert(!functionPtr->getBlocks().empty() &&
"expected at least one basic block in a CFGFunction");
- const BasicBlock &firstBlock = *cfgFunction->begin();
+ const BasicBlock &firstBlock = *functionPtr->begin();
for (auto arg : llvm::enumerate(llvmFunc->args())) {
valueMapping[firstBlock.getArgument(arg.index())] = &arg.value();
}
- if (convertCFGFunction(*cfgFunction, *functionMapping[cfgFunction]))
+ if (convertCFGFunction(*functionPtr, *functionMapping[functionPtr]))
return true;
}
return false;
#include "mlir/Analysis/Dominance.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Support/Functional.h"
void simplifyBasicBlock(BasicBlock *bb) {
for (auto &i : *bb)
- simplifyOperation(&i);
+ if (auto *opInst = dyn_cast<OperationInst>(&i))
+ simplifyOperation(opInst);
}
};
// =============================================================================
#include "mlir/IR/Builders.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"
for (auto &bb : *f) {
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
- auto &inst = *instIt++;
+ auto *inst = dyn_cast<OperationInst>(&*instIt++);
+ if (!inst)
+ continue;
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
- builder.setInsertionPoint(&inst);
- return builder.create<ConstantOp>(inst.getLoc(), value, type);
+ builder.setInsertionPoint(inst);
+ return builder.create<ConstantOp>(inst->getLoc(), value, type);
};
- if (!foldOperation(&inst, existingConstants, constantFactory)) {
+ if (!foldOperation(inst, existingConstants, constantFactory)) {
// At this point the operation is dead, remove it.
// TODO: This is assuming that all constant foldable operations have no
// side effects. When we have side effect modeling, we should verify
// that the operation is effect-free before we remove it. Until then
// this is close enough.
- inst.erase();
+ inst->erase();
}
}
}
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
void ModuleConverter::convertMLFunctions() {
for (Function &fn : *module) {
- if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
- generatedFuncs[mlFunc] = convert(mlFunc);
+ if (fn.isML())
+ generatedFuncs[&fn] = convert(&fn);
}
}
// Use the same name as for ML function; do not add the converted function to
// the module yet to avoid collision.
auto name = mlFunc->getName().str();
- auto *cfgFunc = new CFGFunction(mlFunc->getLoc(), name, mlFunc->getType(),
- mlFunc->getAttrs());
+ auto *cfgFunc = new Function(Function::Kind::CFGFunc, mlFunc->getLoc(), name,
+ mlFunc->getType(), mlFunc->getAttrs());
// Generates the body of the CFG function.
return FunctionConverter(cfgFunc).convert(mlFunc);
// functions.
llvm::DenseMap<Attribute, FunctionAttr> remappingTable;
for (const Function &fn : *module) {
- const auto *mlFunc = dyn_cast<MLFunction>(&fn);
- if (!mlFunc)
+ if (!fn.isML())
continue;
- CFGFunction *convertedFunc = generatedFuncs.lookup(mlFunc);
+ CFGFunction *convertedFunc = generatedFuncs.lookup(&fn);
assert(convertedFunc && "ML function was not converted");
MLIRContext *context = module->getContext();
- auto mlFuncAttr = FunctionAttr::get(mlFunc, context);
+ auto mlFuncAttr = FunctionAttr::get(&fn, context);
auto cfgFuncAttr = FunctionAttr::get(convertedFunc, module->getContext());
remappingTable.insert({mlFuncAttr, cfgFuncAttr});
}
static inline void replaceMLFunctionAttr(
Operation &op, Identifier name, const Function *func,
const llvm::DenseMap<MLFunction *, CFGFunction *> &generatedFuncs) {
- const auto *mlFunc = dyn_cast<MLFunction>(func);
- if (!mlFunc)
+ if (!func->isML())
return;
Builder b(op.getContext());
- auto cfgFunc = generatedFuncs.lookup(mlFunc);
+ auto *cfgFunc = generatedFuncs.lookup(func);
op.setAttr(name, b.getFunctionAttr(cfgFunc));
}
//
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/CFGFunction.h"
-#include "mlir/IR/MLFunction.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoweringUtils.h"
#include "mlir/Transforms/Passes.h"
// Handle iterators with care because we erase in the same loop.
// In particular, step to the next element before erasing the current one.
for (auto it = bb.begin(); it != bb.end();) {
- Instruction &inst = *it;
- OpPointer<AffineApplyOp> affineApplyOp = inst.dyn_cast<AffineApplyOp>();
- ++it;
+ auto *inst = dyn_cast<OperationInst>(&*it++);
+ if (!inst)
+ continue;
+ auto affineApplyOp = inst->dyn_cast<AffineApplyOp>();
if (!affineApplyOp)
continue;
if (expandAffineApply(&*affineApplyOp))
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineStructures.h"
-#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"
// TODO: If we make terminators into Operations then we could turn this
// into a nice Operation::moveBefore(Operation*) method. We just need the
// guarantee that a block is non-empty.
- if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
- auto &entryBB = cfgFunc->front();
- cast<Instruction>(op)->moveBefore(&entryBB, entryBB.begin());
+ // TODO(clattner): This can all be simplified away now.
+ if (currentFunction->isCFG()) {
+ auto &entryBB = currentFunction->front();
+ cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
} else {
auto *mlFunc = cast<MLFunction>(currentFunction);
cast<OperationStmt>(op)->moveBefore(mlFunc->getBody(),
void setInsertionPoint(Operation *op) override {
// Any new operations should be added before this instruction.
- builder.setInsertionPoint(cast<Instruction>(op));
+ builder.setInsertionPoint(cast<OperationInst>(op));
}
private:
GreedyPatternRewriteDriver driver(std::move(patterns));
for (auto &bb : *fn)
for (auto &op : bb)
- driver.addToWorklist(&op);
+ if (auto *opInst = dyn_cast<OperationStmt>(&op))
+ driver.addToWorklist(opInst);
CFGFuncBuilder cfgBuilder(fn);
CFGFuncRewriter rewriter(driver, cfgBuilder);
///
void mlir::applyPatternsGreedily(Function *fn,
OwningRewritePatternList &&patterns) {
- if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
- processCFGFunction(cfg, std::move(patterns));
- } else {
- processMLFunction(cast<MLFunction>(fn), std::move(patterns));
- }
+ if (fn->isCFG())
+ processCFGFunction(fn, std::move(patterns));
+ else if (fn->isML())
+ processMLFunction(fn, std::move(patterns));
}
void mlir::remapFunctionAttrs(
Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
// Look at all instructions in a CFGFunction.
- if (auto *cfgFn = dyn_cast<CFGFunction>(&fn)) {
- for (auto &bb : *cfgFn) {
+ if (fn.isCFG()) {
+ for (auto &bb : fn.getBlockList()) {
for (auto &inst : bb) {
- remapFunctionAttrs(inst, remappingTable);
+ if (auto *op = dyn_cast<OperationInst>(&inst))
+ remapFunctionAttrs(*op, remappingTable);
}
}
return;
}
- // Otherwise, look at MLFunctions. We ignore ExtFunctions.
- auto *mlFn = dyn_cast<MLFunction>(&fn);
- if (!mlFn)
+ // Otherwise, look at MLFunctions. We ignore external functions.
+ if (!fn.isML())
return;
struct MLFnWalker : public StmtWalker<MLFnWalker> {
const DenseMap<Attribute, FunctionAttr> &remappingTable;
};
- MLFnWalker(remappingTable).walk(mlFn);
+ MLFnWalker(remappingTable).walk(&fn);
}
void mlir::remapFunctionAttrs(
#include "mlir/Analysis/Passes.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
-#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"