* \return The corresponding name.
*/
Doc AllocTypeVar(const TypeVar& var) {
+ if (memo_type_.count(var)) {
+ Doc val = memo_type_[var];
+ val << "-malformed-ir";
+ return val;
+ }
std::string name = var->var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Doc val = GetUniqueName("%" + name);
- if (memo_type_.count(var)) {
- val << "-malformed-ir";
- }
memo_type_[var] = val;
if (var->kind != kType) {
val << ": " << Print(var->kind);
* \return The corresponding name.
*/
Doc AllocVar(const Var& var) {
+ // still print if ir is malformed, but show the error.
+ if (memo_.count(var)) {
+ Doc val = memo_[var];
+ val << "-malformed-ir";
+ return val;
+ }
std::string name = var->name_hint();
// always make sure first name is alpha
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name;
}
Doc val = GetUniqueName("%" + name);
- // still print if ir is malformed, but show the error.
- if (memo_.count(var)) {
- val << "-malformed-ir";
- }
memo_[var] = val;
if (var->type_annotation.defined()) {
val << ": " << Print(var->type_annotation);
*/
/*!
- * Copyright (c) 2018 by Contributors
+ * Copyright (c) 2019 by Contributors
*
* \file src/tvm/relay/pass/fuse_ops.cc
*
node->pattern = op_pattern;
this->Update(call->op, nullptr, kOpaque);
const auto* rtype = call->checked_type().as<TensorTypeNode>();
- // pass the message back to all the children it references.
+ // pass the analysis back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) {
const auto* arg_type =
call->args[i]->checked_type().as<TensorTypeNode>();
- // specifically check if result type
+ // specifically check if result type is the same as arguments type
OpPatternKind edge_pattern = op_pattern;
if (edge_pattern == kBroadcast &&
arg_type != nullptr &&
return rhs;
}
/*!
- * \brief Find the least common acenstor of the two nodes.
+ * \brief Find the least common ancestor of the two nodes.
* \param lhs The left node.
* \param rhs The right node.
* \param edge_pattern
* The combined edge pattern across all the parents.
- * \return The least common ancestor of thw two.
+ * \return The least common ancestor of the two.
*/
static Node* LeastCommonAncestor(
Node* lhs,
}
return lhs;
}
-};
-
-DominatorTree DominatorTree::PostDom(common::Arena* arena,
- const IndexedForwardGraph& graph) {
- DominatorTree tree;
- tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
- // reverse topo order
- for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
- size_t index = i - 1;
+ /*!
+ * \brief Find the least common ancestor of a list of nodes.
+ * \param nodes the nodes.
+ * \param edge_pattern
+ * The combined edge pattern across all the parents.
+ * \return The least common ancestor of all nodes.
+ */
+ Node* LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& input_nodes,
+ OpPatternKind* edge_pattern) {
+ auto link = input_nodes.head;
+ if (link == nullptr) {
+ return nullptr;
+ }
+ auto get_node = [&](const IndexedForwardGraph::Edge& edge) {
+ size_t oindex = edge.node->index;
+ CHECK_LT(oindex, nodes.size());
+ Node* onode = nodes[oindex];
+ CHECK(onode != nullptr);
+ return onode;
+ };
+ Node* parent = get_node(link->value);
+ *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
+ link = link->next;
+ for (; link != nullptr; link = link->next) {
+ parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
+ *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
+ }
+ return parent;
+ }
+ /*!
+ * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node.
+ * \param arena The Arena.
+ * \param gnode An IndexedForwardGraph Node.
+ * \return The DominatorTree Node.
+ */
+ Node* GetNode(common::Arena* arena, IndexedForwardGraph::Node* gnode) {
Node* tnode = arena->make<Node>();
- auto* gnode = graph.post_dfs_order[index];
tnode->gnode = gnode;
if (gnode->extern_ref) {
tnode->depth = 1;
} else {
// find the LCAs of all outputs.
OpPatternKind pattern = kElemWise;
- Node* parent = nullptr;
- for (auto link = gnode->outputs.head; link != nullptr; link= link->next) {
- size_t oindex = link->value.node->index;
- CHECK_LT(oindex, tree.nodes.size());
- Node* onode = tree.nodes[oindex];
- CHECK(onode != nullptr);
- if (parent != nullptr) {
- parent = LeastCommonAncestor(parent, onode, &pattern);
- } else {
- parent = onode;
- }
- pattern = CombinePattern(pattern, link->value.pattern);
- }
+ Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
tnode->depth = parent ? parent->depth + 1 : 1;
tnode->parent = parent;
tnode->pattern = pattern;
}
- tree.nodes[index] = tnode;
+ return tnode;
+ }
+};
+
+
+DominatorTree DominatorTree::PostDom(common::Arena* arena,
+ const IndexedForwardGraph& graph) {
+ DominatorTree tree;
+ tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
+ // reverse topo order
+ for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
+ size_t index = i - 1;
+ tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
}
return tree;
}
// merge the current group to the parent if possible.
MergeFromTo(gnode, target);
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
- CommitFuse_(link->value.node, sink, target);;
+ CommitFuse_(link->value.node, sink, target);
}
}
/*!
Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
- if (ret_group == gmap_.at(tuple)) {
+ if (ret_group->root_ref == tuple) {
return ExprMutator::VisitExpr_(tuple);
}
// This tuple is an intermediate node in the group
auto* ret_group = gmap_.at(tuple_get)->FindRoot();
auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
- if (ret_group == gmap_.at(tuple_get)) {
+ if (ret_group->root_ref == tuple_get) {
if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
// Isolated. This case occurs when tuple is created by an Opaque op
// e.g. multibox_transform_loc
}
};
-// Temporary solution, should be handled by implementing a "FunctionPass"
-// which applies fusion to each function.
-struct GlobalVarLiveness : ExprVisitor {
- Module module;
- std::set<GlobalVar> visited;
-
- explicit GlobalVarLiveness(const Module& mod) : module(mod), visited() {}
-
- void VisitExpr_(const GlobalVarNode* gvar_node) {
- auto gvar = GetRef<GlobalVar>(gvar_node);
- if (visited.find(gvar) == visited.end()) {
- visited.insert(gvar);
- this->VisitExpr(this->module->Lookup(gvar));
- }
- }
-};
-
-std::set<GlobalVar> LiveGlobals(const Module& mod, const Expr& expr) {
- auto gvl = GlobalVarLiveness(mod);
- gvl.VisitExpr(expr);
- return gvl.visited;
-}
-
Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
- // First we convert all chains of fusable ops into
- // abstracted functions which we mark as primtive
- // then we convert these primtive functions into
- // new operators.
- if (!module.defined()) {
- return FuseMutator().Transform(expr, fuse_opt_level);
- } else {
- auto lgvs = LiveGlobals(module, expr);
- for (auto lv : lgvs) {
- auto body = module->Lookup(lv);
- auto e = FuseMutator().Transform(body, fuse_opt_level);
- module->Add(lv, Downcast<Function>(e), true);
- }
- return FuseMutator().Transform(expr, fuse_opt_level);
- }
+ return FuseMutator().Transform(expr, fuse_opt_level);
}
namespace transform {