TVM_DLL bool ConstantCheck(const Expr& e);
/*!
+ * \brief Check whether an expression is in the basic block normal form.
+ *
+ * \param e the expression.
+ *
+ * \return whether the expression is in the basic block normal form.
+ */
+TVM_DLL bool BasicBlockNormalFormCheck(const Expr& e);
+
+/*!
* \brief Check that each Var is only bound once.
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
/*!
+ * \brief Turn an expression to Basic Block Normal Form.
+ *
+ * We define a block as a group of expressions implied by the scope structure.
+ *
+ * Each graph node can only belong to a single block.
+ *
+ * For any value that is being used in multiple blocks, it has to be referred
+ * by a Var which is defined in a block, whose scope is the least common ancestor
+ * of blocks this value is used.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ToBasicBlockNormalForm();
+
+/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
"""
return _ffi_api.check_constant(expr)
+def check_basic_block_normal_form(expr):
+ """Check whether an expression is in the basic block form
+
+ Parameters
+ ----------
+ expr : tvm.relay.Expr
+ The input expression
+
+ Returns
+ -------
+ result : bool
+ Whether the expression is in the basic block form.
+ """
+ return _ffi_api.check_basic_block_normal_form(expr)
+
def free_vars(expr):
"""Get free Vars from expression expr in Post DFS order.
"""
return _ffi_api.ToANormalForm()
+def ToBasicBlockNormalForm():
+ """Turn an expression to Basic Block Normal Form.
+ We define a block as a group of expressions implied by the scope structure.
+ Each graph node can only belong to a single block.
+ For any value that is being used in multiple blocks, it has to be referred
+ by a Var which is defined in a block, whose scope is the least common ancestor
+ of blocks this value is used.
+
+ Returns
+ -------
+ ret: tvm.transform.Pass
+ The registered pass that transforms an expression into Basic Block Normal Form.
+ """
+ return _ffi_api.ToBasicBlockNormalForm()
+
def ToCPS(expr, mod=None):
"""
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
+ for (const auto& p : f->params) {
+ Depend(b, p);
+ }
Depend(b, f->body);
graph_.post_dfs_order.push_back(b);
}
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(l)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
+ Depend(b, l->var);
Depend(b, l->value);
Depend(b, l->body);
graph_.post_dfs_order.push_back(b);
Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
+ pass_seqs.push_back(transform::ToBasicBlockNormalForm());
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
+ pass_seqs.push_back(transform::ToBasicBlockNormalForm());
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
return ret;
}
+ /*! \brief get the number of let bindings in the let list.
+ *
+ * \return the let list size.
+ */
+ size_t size() const { return lets_.size(); }
+
/*! \brief generate an LetList and wrap the result automatically.
*
* \param f a function that generate the unwrapped Expr.
#include <memory>
#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "../analysis/dependency_graph.h"
+#include "let_list.h"
namespace tvm {
namespace relay {
~TreeBranchNode() {}
};
+struct ScopeNode;
+using Scope = std::shared_ptr<ScopeNode>;
+using NodeScopeMap = std::unordered_map<DependencyGraph::Node*, Scope>;
+using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
+
+/* Invariant: when parent is null level is 0
+ * Invariant: when parent is not null level is 1 + parent->level
+ */
+struct ScopeNode {
+ // the level of the scope
+ size_t level;
+ // the parent scope
+ Scope parent;
+ // the corresponding let list which holds all let bindings in the scope
+ std::shared_ptr<LetList> let_list = std::make_shared<LetList>();
+ explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {}
+ ScopeNode() : level(0) {}
+};
+
+/*! \brief Calculate the scope of nodes in the dependency graph by least common ancestor.
+ *
+ * \param dg the input dependency graph
+ * \param expr_scope the output node -> scope mapping for all nodes.
+ * \param lifted_exprs the output set of expressions whose scope is lifted due to dependency
+ */
+std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg);
+
+/*! \brief find the least common ancestor of lhs scope and rhs scope.
+ */
+Scope LCA(Scope lhs, Scope rhs);
+
+/* Special care is needed to handle local recursion.
+ * Fill additionally take a (possibly null) Var argument,
+ * If it is not null, Fill is required to bind the transformed result to that var.
+ */
+class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
+ public:
+ static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope);
+
+ // For basic block normal form, bind expressions only if the original expression's
+ // scope should be lifted
+ static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
+ NodeScopeMap* node_scope, ExprSet* lifted);
+
+ private:
+ const DependencyGraph& dg_;
+ NodeScopeMap* node_scope_ = nullptr;
+ std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo;
+ // a set of Expressions to include for let bindings. If set to nullptr
+ // all Exprs will be pushed to the let list.
+ ExprSet* include_set_ = nullptr;
+
+ Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set)
+ : dg_(dg), node_scope_(node_scope), include_set_(include_set) {}
+
+ Scope GetScope(const Expr& e);
+ Scope GetSubScope(const Expr& e, size_t i);
+
+ Expr VisitExpr(const Expr& e, const Var& v) final;
+ Expr VisitExpr(const Expr& e);
+
+ Expr Atomic(const Expr& e, const Var& v);
+ // Bind expression `now` to var `v` if the original expression is in the include set, or if
+ // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly.
+ Expr Compound(const Expr& orig, const Expr& now, const Var& v);
+
+ Expr VisitExpr_(const CallNode* c, const Var& v) final;
+ Expr VisitExpr_(const TupleNode* t, const Var& v) final;
+ Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final;
+ Expr VisitExpr_(const RefCreateNode* r, const Var& v) final;
+ Expr VisitExpr_(const RefReadNode* r, const Var& v) final;
+ Expr VisitExpr_(const RefWriteNode* r, const Var& v) final;
+ Expr VisitExpr_(const IfNode* i, const Var& v) final;
+ Expr VisitExpr_(const FunctionNode* f, const Var& v) final;
+ Expr VisitExpr_(const LetNode* l, const Var& v) final;
+ Expr VisitExpr_(const ConstantNode* c, const Var& v) final;
+ Expr VisitExpr_(const VarNode* vn, const Var& v) final;
+ Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final;
+ Expr VisitExpr_(const OpNode* op, const Var& v) final;
+ Expr VisitExpr_(const ConstructorNode* c, const Var& v) final;
+ Expr VisitExpr_(const MatchNode* m, const Var& v) final;
+};
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TRANSFORMS_PASS_UTIL_H_
namespace tvm {
namespace relay {
-struct ScopeNode;
-using Scope = std::shared_ptr<ScopeNode>;
-
-/* Invariant: when parent is null level is 0
- *
- * Invariant: when parent is not null level is 1 + parent->level
- */
-struct ScopeNode {
- size_t level;
- Scope parent;
- std::shared_ptr<LetList> ll = std::make_shared<LetList>();
- explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {}
- ScopeNode() : level(0) {}
-};
-
-Scope ChildScope(const Scope& s) { return std::make_shared<ScopeNode>(s); }
-
Scope LCA(Scope lhs, Scope rhs) {
while (lhs != rhs) {
if (lhs->level > rhs->level) {
return lhs;
}
-std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGraph& dg) {
- std::unordered_map<DependencyGraph::Node*, Scope> expr_scope;
+std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg) {
+ NodeScopeMap expr_scope;
+ ExprSet lifted_exprs;
+ std::unordered_map<DependencyGraph::Node*, Expr> node_to_expr;
+ for (auto expr_node : dg.expr_node) {
+ node_to_expr[expr_node.second] = expr_node.first;
+ }
bool global_scope_used = false;
Scope global_scope = std::make_shared<ScopeNode>();
+
for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
DependencyGraph::Node* n = *it;
auto iit = n->parents.head;
global_scope_used = true;
} else {
s = expr_scope.at(iit->value);
+ const auto original_s = s;
iit = iit->next;
for (; iit != nullptr; iit = iit->next) {
s = LCA(s, expr_scope.at(iit->value));
}
+ if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) {
+ // filter out exprs whose scope do not matter
+ Expr expr = node_to_expr[n];
+ if (!expr.as<OpNode>()) {
+ lifted_exprs.insert(expr);
+ }
+ }
+ }
+ if (n->new_scope) {
+ auto child_scope = std::make_shared<ScopeNode>(s);
+ expr_scope.insert({n, child_scope});
+ } else {
+ expr_scope.insert({n, s});
}
- expr_scope.insert({n, n->new_scope ? ChildScope(s) : s});
}
CHECK(global_scope_used);
- return expr_scope;
+ return std::make_pair(expr_scope, lifted_exprs);
}
-/* Special care is needed to handle local recursion.
- * Fill additionally take a (possibly null) Var argument,
- * If it is not null, Fill is required to bind the transformed result to that var.
- */
-class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
- public:
- static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg,
- std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) {
- Fill fi(dg, node_scope);
- return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
- }
-
- private:
- const DependencyGraph& dg_;
- std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
- std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo;
+Expr Fill::ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) {
+ Fill fi(dg, node_scope, nullptr);
+ return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e));
+}
- Fill(const DependencyGraph& dg, std::unordered_map<DependencyGraph::Node*, Scope>* node_scope)
- : dg_(dg), node_scope_(node_scope) {}
+// For basic block normal form, bind expressions only if the original expression's scope
+// should be lifted
+Expr Fill::ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
+ NodeScopeMap* node_scope, ExprSet* lifted) {
+ Fill fi(dg, node_scope, lifted);
+ auto var = fi.VisitExpr(e);
+ return fi.GetScope(e)->let_list->Get(var);
+}
- Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); }
+Scope Fill::GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); }
- Scope GetSubScope(const Expr& e, size_t i) {
- DependencyGraph::Node* n = dg_.expr_node.at(e);
- auto h = n->children.head;
- while (i != 0) {
- CHECK(h);
- --i;
- h = h->next;
- }
+Scope Fill::GetSubScope(const Expr& e, size_t i) {
+ DependencyGraph::Node* n = dg_.expr_node.at(e);
+ auto h = n->children.head;
+ while (i != 0) {
CHECK(h);
- return node_scope_->at(h->value);
+ --i;
+ h = h->next;
}
+ CHECK(h);
+ return node_scope_->at(h->value);
+}
- Expr VisitExpr(const Expr& e, const Var& v) final {
- if (memo.count(e) == 0) {
- memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
- } else if (v.defined()) {
- GetScope(e)->ll->Push(v, memo.at(e));
- }
- auto ret = memo.at(e);
- CHECK(IsAtomic(ret));
- return ret;
+Expr Fill::VisitExpr(const Expr& e, const Var& v) {
+ if (memo.count(e) == 0) {
+ memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
+ } else if (v.defined()) {
+ GetScope(e)->let_list->Push(v, memo.at(e));
}
+ auto ret = memo.at(e);
+ // if no include_set is specified, every expression should be atomic.
+ if (include_set_ == nullptr) CHECK(IsAtomic(ret));
+ return ret;
+}
- Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); }
+Expr Fill::VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); }
- Expr Atomic(const Expr& e, const Var& v) { return v.defined() ? GetScope(e)->ll->Push(v, e) : e; }
+Expr Fill::Atomic(const Expr& e, const Var& v) {
+ return v.defined() ? GetScope(e)->let_list->Push(v, e) : e;
+}
- Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
- Var var = v.defined() ? v : Var(String("x"), Type());
- return GetScope(orig)->ll->Push(var, now);
+// Bind expression `now` to var `v` if the original expression is in the include set, or if
+// v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly
+Expr Fill::Compound(const Expr& orig, const Expr& now, const Var& v) {
+ Var var = v.defined() ? v : Var(String("x"), Type());
+ bool not_included = include_set_ && include_set_->find(orig) == include_set_->end();
+ if (!v.defined() && not_included) {
+ return now;
+ } else {
+ return GetScope(orig)->let_list->Push(var, now);
}
+}
- Expr VisitExpr_(const CallNode* c, const Var& v) final {
- Expr e = GetRef<Expr>(c);
- std::vector<Expr> args;
- for (const auto& a : c->args) {
- args.push_back(VisitExpr(a));
- }
- return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
+Expr Fill::VisitExpr_(const CallNode* c, const Var& v) {
+ Expr e = GetRef<Expr>(c);
+ std::vector<Expr> args;
+ for (const auto& a : c->args) {
+ args.push_back(VisitExpr(a));
}
+ return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
+}
- Expr VisitExpr_(const TupleNode* t, const Var& v) final {
- Expr e = GetRef<Expr>(t);
- std::vector<Expr> fields;
- for (const auto& a : t->fields) {
- fields.push_back(VisitExpr(a));
- }
- return Compound(e, Tuple(fields), v);
+Expr Fill::VisitExpr_(const TupleNode* t, const Var& v) {
+ Expr e = GetRef<Expr>(t);
+ std::vector<Expr> fields;
+ for (const auto& a : t->fields) {
+ fields.push_back(VisitExpr(a));
}
+ return Compound(e, Tuple(fields), v);
+}
- Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
- Expr e = GetRef<Expr>(t);
- return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v);
- }
+Expr Fill::VisitExpr_(const TupleGetItemNode* t, const Var& v) {
+ Expr e = GetRef<Expr>(t);
+ return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v);
+}
- Expr VisitExpr_(const RefCreateNode* r, const Var& v) final {
- Expr e = GetRef<Expr>(r);
- return Compound(e, RefCreate(VisitExpr(r->value)), v);
- }
+Expr Fill::VisitExpr_(const RefCreateNode* r, const Var& v) {
+ Expr e = GetRef<Expr>(r);
+ return Compound(e, RefCreate(VisitExpr(r->value)), v);
+}
- Expr VisitExpr_(const RefReadNode* r, const Var& v) final {
- Expr e = GetRef<Expr>(r);
- return Compound(e, RefRead(VisitExpr(r->ref)), v);
- }
+Expr Fill::VisitExpr_(const RefReadNode* r, const Var& v) {
+ Expr e = GetRef<Expr>(r);
+ return Compound(e, RefRead(VisitExpr(r->ref)), v);
+}
- Expr VisitExpr_(const RefWriteNode* r, const Var& v) final {
- Expr e = GetRef<Expr>(r);
- return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v);
- }
+Expr Fill::VisitExpr_(const RefWriteNode* r, const Var& v) {
+ Expr e = GetRef<Expr>(r);
+ return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v);
+}
- Expr VisitExpr_(const IfNode* i, const Var& v) final {
- Expr e = GetRef<Expr>(i);
- Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
- GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
- return Compound(e, ret, v);
- }
+Expr Fill::VisitExpr_(const IfNode* i, const Var& v) {
+ Expr e = GetRef<Expr>(i);
+ Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->let_list->Get(VisitExpr(i->true_branch)),
+ GetSubScope(e, 2)->let_list->Get(VisitExpr(i->false_branch)));
+ return Compound(e, ret, v);
+}
- Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
- Expr e = GetRef<Expr>(f);
- Expr ret;
- if (f->HasNonzeroAttr(attr::kPrimitive)) {
- ret = e;
- } else {
- ret = Function(f->params, GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), f->ret_type,
- f->type_params, f->attrs);
- }
- return Compound(e, ret, v);
+Expr Fill::VisitExpr_(const FunctionNode* f, const Var& v) {
+ Expr e = GetRef<Expr>(f);
+ Expr ret;
+ if (f->HasNonzeroAttr(attr::kPrimitive)) {
+ ret = e;
+ } else {
+ ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type,
+ f->type_params, f->attrs);
}
+ return Compound(e, ret, v);
+}
- Expr VisitExpr_(const LetNode* l, const Var& v) final {
- Expr e = GetRef<Expr>(l);
- VisitExpr(l->value, l->var);
- Expr ret = GetSubScope(e, 0)->ll->Get(VisitExpr(l->body));
- return Compound(e, ret, v);
- }
+Expr Fill::VisitExpr_(const LetNode* l, const Var& v) {
+ Expr e = GetRef<Expr>(l);
+ VisitExpr(l->value, l->var);
+ Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body));
+ return Compound(e, ret, v);
+}
- Expr VisitExpr_(const ConstantNode* c, const Var& v) final {
- Expr e = GetRef<Expr>(c);
- return Compound(e, e, v);
- }
+Expr Fill::VisitExpr_(const ConstantNode* c, const Var& v) {
+ Expr e = GetRef<Expr>(c);
+ return Compound(e, e, v);
+}
- Expr VisitExpr_(const VarNode* vn, const Var& v) final {
- Expr e = GetRef<Expr>(vn);
- return Atomic(e, v);
- }
+Expr Fill::VisitExpr_(const VarNode* vn, const Var& v) {
+ Expr e = GetRef<Expr>(vn);
+ return Atomic(e, v);
+}
- Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
- GlobalVar gv = GetRef<GlobalVar>(gvn);
- return Atomic(gv, v);
- }
+Expr Fill::VisitExpr_(const GlobalVarNode* gvn, const Var& v) {
+ GlobalVar gv = GetRef<GlobalVar>(gvn);
+ return Atomic(gv, v);
+}
- Expr VisitExpr_(const OpNode* op, const Var& v) final {
- Expr e = GetRef<Expr>(op);
- return Atomic(e, v);
- }
+Expr Fill::VisitExpr_(const OpNode* op, const Var& v) {
+ Expr e = GetRef<Expr>(op);
+ return Atomic(e, v);
+}
- Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
- Expr e = GetRef<Expr>(c);
- return Atomic(e, v);
- }
+Expr Fill::VisitExpr_(const ConstructorNode* c, const Var& v) {
+ Expr e = GetRef<Expr>(c);
+ return Atomic(e, v);
+}
- Expr VisitExpr_(const MatchNode* m, const Var& v) final {
- Expr e = GetRef<Expr>(m);
- Expr data = VisitExpr(m->data);
- std::vector<Clause> clauses;
- for (const Clause& c : m->clauses) {
- clauses.push_back(
- Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
- }
- return Compound(e, Match(data, clauses, m->complete), v);
+Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) {
+ Expr e = GetRef<Expr>(m);
+ Expr data = VisitExpr(m->data);
+ std::vector<Clause> clauses;
+ for (const Clause& c : m->clauses) {
+ clauses.push_back(
+ Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->let_list->Get(VisitExpr(c->rhs))));
}
-};
+ return Compound(e, Match(data, clauses, m->complete), v);
+}
Expr ToANormalFormAux(const Expr& e) {
/* When you lift a lambda, what is inside is also being lift.
* Every scope additionally contain a LetList which collect all value of that scope.
* We do an additional pass to fill all the LetList and we are done.
*/
- std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
- return Fill::ToANormalForm(e, dg, &node_scope);
+ std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
+ return Fill::ToANormalForm(e, dg, &scopes.first);
}
IRModule ToANormalForm(const IRModule& m) {
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_basic_block_normal_form.cc
+ *
+ * \brief Turn an expression to the basic normal form.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/support/logging.h>
+
+#include "../../support/arena.h"
+#include "../analysis/dependency_graph.h"
+#include "let_list.h"
+#include "pass_util.h"
+
+namespace tvm {
+namespace relay {
+
+Expr ToBasicBlockNormalFormAux(const Expr& e) {
+ // calculate all the dependency between nodes.
+ support::Arena arena;
+ DependencyGraph dg = DependencyGraph::Create(&arena, e);
+ /* The scope of the whole expr is global.
+ * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
+ * We also record the set of expressions whose scope is lifted.
+ */
+ std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
+ return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);
+}
+
+IRModule ToBasicBlockNormalForm(const IRModule& mod) {
+ DLOG(INFO) << "ToBBlock:" << std::endl << mod;
+
+ tvm::Map<GlobalVar, Function> updates;
+ auto funcs = mod->functions;
+ for (const auto& it : funcs) {
+ CHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables";
+ if (const auto* n = it.second.as<FunctionNode>()) {
+ if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
+ }
+ Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second);
+ updates.Set(it.first, Downcast<Function>(ret));
+ }
+
+ for (auto pair : updates) {
+ mod->Add(pair.first, pair.second, true);
+ }
+
+ DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod;
+
+ return mod;
+}
+
+bool BasicBlockNormalFormCheck(const Expr& e) {
+ // calculate all the dependency between nodes.
+ support::Arena arena;
+ DependencyGraph dg = DependencyGraph::Create(&arena, e);
+ std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
+ for (auto expr : scopes.second) {
+ LOG(FATAL) << "The expression below violates the basic block normal form in that "
+ << "its scope should be lifted:\n"
+ << expr;
+ }
+ return scopes.second.size() == 0;
+}
+
+TVM_REGISTER_GLOBAL("relay.analysis.check_basic_block_normal_form")
+ .set_body_typed(BasicBlockNormalFormCheck);
+
+namespace transform {
+
+Pass ToBasicBlockNormalForm() {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule m, PassContext pc) { return relay::ToBasicBlockNormalForm(m); };
+ return CreateModulePass(pass_func, 1, "ToBasicBlockNormalForm", {});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.ToBasicBlockNormalForm")
+ .set_body_typed(ToBasicBlockNormalForm);
+
+} // namespace transform
+
+} // namespace relay
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+from tvm import relay
+from tvm.relay.analysis import check_basic_block_normal_form
+
+def test_one_block():
+ x = relay.var('x')
+ y = relay.add(x, x)
+ z = relay.add(x, y)
+ check_basic_block_normal_form(z)
+
+def test_let():
+ x = relay.var('x')
+ y = relay.var('y')
+ body = relay.Let(y, x, y)
+ check_basic_block_normal_form(body)
+
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_invalid_if():
+ cond = relay.var('cond', dtype='bool', shape=())
+ shared = relay.var('shared')
+ true_branch = shared
+ false_branch = relay.add(shared, shared)
+ body = relay.If(cond, true_branch, false_branch)
+ """
+ The program below violates basic block normal form, as the scope of %shared
+ is ambiguous and should not be in that of true branch.
+
+ free_var %cond: bool
+ if (%cond) {
+ free_var %shared
+ %shared
+ } else {
+ add(%shared, %shared)
+ }
+ """
+ check_basic_block_normal_form(body)
+
+def test_valid_if():
+ cond = relay.var('cond', dtype='bool', shape=())
+ shared = relay.var('shared')
+ true_branch = shared
+ false_branch = relay.add(shared, shared)
+ body = relay.If(cond, true_branch, false_branch)
+ shared_bound = relay.var('shared_bound', shape=(1,), dtype='float32')
+ body = relay.Let(shared, shared_bound, body)
+ """
+ The program below uses let binding to control the scope of %shared, which
+ follows the basic block normal form.
+
+ free_var %shared_bound: Tensor[(1), float32]
+ let %shared = %shared_bound;
+ free_var %cond: bool
+ if (%cond) {
+ %shared
+ } else {
+ add(%shared, %shared)
+ }
+ """
+ check_basic_block_normal_form(body)
+
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_invalid_if2():
+ """
+ fn (%x: float32) {
+ %0 = equal(%x, 2f);
+ if (%0) {
+ %1 = add(%x, 1f);
+ multiply(%1, 2f)
+ } else {
+ multiply(%1, 1f)
+ }
+ }
+ """
+ x = relay.var('x', shape=(), dtype='float32')
+ one = relay.const(1, dtype='float32')
+ two = relay.const(2, dtype='float32')
+ v1 = relay.add(x, one)
+ v2 = relay.equal(x, two)
+ true_branch = relay.multiply(v1, two)
+ false_branch = relay.multiply(v1, one)
+ body = relay.If(v2, true_branch, false_branch)
+ func = relay.Function([x], body)
+ check_basic_block_normal_form(func)
+
+def test_valid_if2():
+ """
+ fn (%x: float32) {
+ let %v1 = add(%x, 1f);
+ %0 = equal(%x, 2f);
+ if (%0) {
+ multiply(%v1, 2f)
+ } else {
+ multiply(%v1, 1f)
+ }
+ }
+ """
+ x = relay.var('x', shape=(), dtype='float32')
+ one = relay.const(1, dtype='float32')
+ two = relay.const(2, dtype='float32')
+ v1 = relay.var('v1')
+ v2 = relay.equal(x, two)
+ true_branch = relay.multiply(v1, two)
+ false_branch = relay.multiply(v1, one)
+ body = relay.If(v2, true_branch, false_branch)
+ body = relay.Let(v1, relay.add(x, one), body)
+ func = relay.Function([x], body)
+ check_basic_block_normal_form(func)
+
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_func():
+ x = relay.var('x', shape=(1,), dtype='float32')#, a)
+ y = relay.var('y', shape=(1,), dtype='float32')#, a)
+ z = relay.var('z', shape=(1,), dtype='float32')#, a)
+ x2 = relay.add(x, x)
+ func_a = relay.Function([y], relay.add(x2, y)) #, a, [a])
+ func_b = relay.Function([z], relay.add(x2, z)) #, a, [a])
+ body = relay.Tuple([func_a, func_b])
+ body = relay.Function([x], body)
+ """
+ fn (%x: Tensor[(1), float32]) {
+ %1 = fn (%y: Tensor[(1), float32]) {
+ %0 = add(%x, %x);
+ add(%0, %y)
+ };
+ %2 = fn (%z: Tensor[(1), float32]) {
+ add(%0, %z)
+ };
+ (%1, %2)
+ }
+ """
+ check_basic_block_normal_form(body)
+
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_higher_order_return():
+ x = relay.var('x', shape=(1,), dtype='float32')#, a)
+ y = relay.var('y', shape=(1,), dtype='float32')#, a)
+ z = relay.var('z', shape=(1,), dtype='float32')#, a)
+ x2 = relay.add(x, x)
+ func_a = relay.Function([y], relay.add(x2, y)) #, a, [a])
+ func_b = relay.Function([z], relay.add(x2, z)) #, a, [a])
+ body = relay.Tuple([func_a, func_b])
+ body = relay.Function([x], body)
+ """
+ fn (%x: Tensor[(1), float32]) {
+ %1 = fn (%y: Tensor[(1), float32]) {
+ %0 = add(%x, %x);
+ add(%0, %y)
+ };
+ %2 = fn (%z: Tensor[(1), float32]) {
+ add(%0, %z)
+ };
+ (%1, %2)
+ }
+ """
+ check_basic_block_normal_form(body)
+
+
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_higher_order_nested():
+ x = relay.var('x', dtype='float32', shape=(1,))
+ s = relay.var('s', dtype='float32', shape=(1,))
+ shared = relay.add(s, s)
+ func_true = relay.Function([x], relay.add(x, shared))
+ choice_t = relay.FuncType([], relay.scalar_type('bool'))
+ f = relay.Var('f', choice_t)
+ z = relay.Var('z')
+ body = relay.If(f(), func_true, relay.Function([z], relay.add(z, shared)))
+ top = relay.Function([f, s], body)
+ """
+ fn (%f: fn () -> bool, %s: Tensor[(1), float32]) {
+ %0 = %f();
+ if (%0) {
+ fn (%x: Tensor[(1), float32]) {
+ %1 = add(%s, %s);
+ add(%x, %1)
+ }
+ } else {
+ fn (%z) {
+ add(%z, %1)
+ }
+ }
+ }
+ """
+ check_basic_block_normal_form(top)
+
+
+if __name__ == '__main__':
+ pytest.main([__file__])
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import numpy as np
+import tvm
+from tvm import te
+from tvm import relay
+from tvm.relay.analysis import detect_feature
+from tvm.relay import op, create_executor, transform
+from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions, count
+from tvm.relay.analysis import Feature
+from tvm.relay.analysis import check_basic_block_normal_form
+
+
+def run_opt_pass(expr, passes):
+ passes = passes if isinstance(passes, list) else [passes]
+ mod = tvm.IRModule.from_expr(expr)
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
+ mod = seq(mod)
+ entry = mod["main"]
+ return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def check_eval(expr, expected_result, mod=None, rtol=1e-07):
+ ctx = tvm.context("llvm", 0)
+ intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
+
+ result = intrp.evaluate(expr)
+ np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
+
+
+def test_no_explicit_bind():
+ x = relay.const(1)
+ y = op.add(x, x)
+ z = op.add(y, y)
+ f = relay.Function([], op.add(z, z))
+ """
+ fn () {
+ %0 = add(1, 1);
+ %1 = add(%0, %0);
+ add(%1, %1)
+ }
+ """
+ assert not Feature.fLet in detect_feature(f)
+ bblock = run_opt_pass(f, transform.ToBasicBlockNormalForm())
+ assert Feature.fLet not in detect_feature(bblock)
+ check_eval(f(), 8.0)
+ check_eval(bblock(), 8.0)
+ check_basic_block_normal_form(bblock)
+
+def test_top_level_nested_if():
+ x = relay.var('x', shape=(), dtype='bool')
+ y = relay.var('y', shape=(), dtype='float32')
+ z = relay.var('z', shape=(), dtype='float32')
+ cond_t = relay.const(True)
+ cond_f = relay.const(False)
+ one = relay.const(1, dtype='float32')
+ three = relay.const(3, dtype='float32')
+ y2 = relay.add(y, y)
+ z2 = relay.add(z, z)
+ true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2))
+ false_branch = relay.If(cond_f, z2, one)
+ body = relay.If(x, true_branch, false_branch)
+ """
+ free_var %x: bool
+ if (%x) {
+ if (True) {
+ free_var %z: float32
+ %0 = add(%z, %z);
+ free_var %y: float32
+ %1 = add(%y, %y);
+ add(%0, %1)
+ } else {
+ add(3f, %1)
+ }
+ } else {
+ if (False) {
+ %0
+ } else {
+ 1f
+ }
+ }
+ """
+ def expected():
+ x = relay.var('x', shape=(), dtype='bool')
+ y = relay.var('y', shape=(), dtype='float32')
+ z = relay.var('z', shape=(), dtype='float32')
+ cond_t = relay.const(True)
+ cond_f = relay.const(False)
+ one = relay.const(1, dtype='float32')
+ three = relay.const(3, dtype='float32')
+ y2 = relay.var('y2')
+ z2 = relay.var('z2')
+ true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2))
+ true_branch = relay.Let(y2, relay.add(y, y), true_branch)
+ false_branch = relay.If(cond_f, z2, one)
+ body = relay.If(x, true_branch, false_branch)
+ body = relay.Let(z2, relay.add(z, z), body)
+ return body
+
+ bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()])
+ """
+ free_var %z: float32
+ let %x: float32 = add(%z, %z) /* ty=float32 */;
+ free_var %x1: bool
+ if (%x1) {
+ free_var %y: float32
+ let %x2: float32 = add(%y, %y) /* ty=float32 */;
+ if (True /* ty=bool */) {
+ add(%x, %x2) /* ty=float32 */
+ } else {
+ add(3f /* ty=float32 */, %x2) /* ty=float32 */
+ }
+ } else {
+ if (False /* ty=bool */) {
+ %x
+ } else {
+ 1f /* ty=float32 */
+ }
+ }
+ """
+ expected_output = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True)
+
+def test_nested_if():
+ x = relay.var('x', shape=(), dtype='bool')
+ y = relay.var('y', shape=(), dtype='float32')
+ cond_t = relay.const(True)
+ cond_f = relay.const(False)
+ one = relay.const(1, dtype='float32')
+ two = relay.const(2, dtype='float32')
+ three = relay.const(3, dtype='float32')
+ y2 = relay.add(y, y)
+ true_branch = relay.If(cond_t, y2, relay.add(three, y2))
+ false_branch = relay.If(cond_f, two, one)
+ body = relay.If(x, true_branch, false_branch)
+ """
+ free_var %x: bool
+ if (%x) {
+ if (True) {
+ free_var %y: float32
+ %0 = add(%y, %y);
+ %0
+ } else {
+ add(3f, %0)
+ }
+ } else {
+ if (False) {
+ 2f
+ } else {
+ 1f
+ }
+ }
+ """
+ def expected():
+ x = relay.var('x', shape=(), dtype='bool')
+ y = relay.var('y', shape=(), dtype='float32')
+ cond_t = relay.const(True)
+ cond_f = relay.const(False)
+ one = relay.const(1, dtype='float32')
+ two = relay.const(2, dtype='float32')
+ three = relay.const(3, dtype='float32')
+ y2 = relay.var('y2')
+ true_branch = relay.If(cond_t, y2, relay.add(three, y2))
+ true_branch = relay.Let(y2, relay.add(y, y), true_branch)
+ false_branch = relay.If(cond_f, two, one)
+ body = relay.If(x, true_branch, false_branch)
+ return body
+
+ bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()])
+ """
+ free_var %x: bool
+ if (%x) {
+ free_var %y: float32
+ let %x1: float32 = add(%y, %y) /* ty=float32 */;
+ if (True /* ty=bool */) {
+ %x1
+ } else {
+ add(3f /* ty=float32 */, %x1) /* ty=float32 */
+ }
+ } else {
+ if (False /* ty=bool */) {
+ 2f /* ty=float32 */
+ } else {
+ 1f /* ty=float32 */
+ }
+ }
+ """
+ expected_output = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True)
+ check_basic_block_normal_form(bblock)
+
+
+# make sure we do not infinite loop.
+# it is too large so we won't check for the exact program.
+def test_recursion():
+ """
+ Program:
+ let f(n: i32) -> i32 = {
+ m = (n * 2)
+ if (n == 0) {
+ return m;
+ } else {
+ return m + f(n - 1);
+ }
+ }
+ f(5);
+ """
+ mod = tvm.IRModule()
+ i64 = relay.TensorType((), 'int64')
+ f = relay.GlobalVar("f")
+ n = relay.Var("n", i64)
+ m = n * relay.const(2, 'int64')
+ cond = relay.equal(n, relay.const(0, 'int64'))
+ false_branch = m + f(n - relay.const(1, 'int64'))
+ funcbody = relay.If(cond, m, false_branch)
+ value = relay.Function([n], funcbody, i64, [])
+ mod[f] = value
+ check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
+ old_f = mod[f]
+ mod = transform.ToBasicBlockNormalForm()(mod)
+ f = mod[f]
+ check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
+ check_basic_block_normal_form(f)
+
+def test_ref():
+ i = relay.Var('i')
+ iv = relay.Var('iv')
+ u = relay.Var('u')
+ uv = relay.Var('uv')
+ body = relay.add(iv, uv)
+ body = relay.Let(uv, relay.RefRead(i), body)
+ body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
+ body = relay.Let(iv, relay.RefRead(i), body)
+ body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
+ check_eval(body, 3)
+ opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+ check_eval(opt_body, 3)
+ check_basic_block_normal_form(opt_body)
+
+
+def test_nat_add():
+ mod = tvm.IRModule()
+ p = Prelude(mod)
+ add_nat_definitions(p)
+ nat = p.nat
+ add = p.add
+ s = p.s
+ z = p.z
+ ctx = tvm.context("llvm", 0)
+ intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
+ assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
+ assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
+ expr = add(s(z()), s(z()))
+ f = relay.GlobalVar("f")
+ mod[f] = relay.Function([], expr)
+ mod = transform.ToBasicBlockNormalForm()(mod)
+ opt_expr = mod["f"]
+ assert count(p, intrp.evaluate(opt_expr.body)) == 2
+ assert not Feature.fLet in detect_feature(mod[add])
+ check_basic_block_normal_form(opt_expr)
+
+def test_let():
+ def test_let1():
+ x = relay.Var("x")
+ c = relay.const(4.0, 'float32')
+ body = relay.Let(x, c, x)
+ body = run_opt_pass(body, transform.InferType())
+ """
+ let %x: float32 = 4f /* ty=float32 */;
+ %x
+ """
+ opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+ assert tvm.ir.structural_equal(body, opt_body)
+ check_basic_block_normal_form(opt_body)
+
+ def test_let1_1():
+ x = relay.Var("y")
+ d = relay.const(4.0, 'float32')
+ body = relay.Let(x, d, relay.add(x,x))
+ body = run_opt_pass(body, transform.InferType())
+ opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+ assert tvm.ir.structural_equal(body, opt_body)
+ check_basic_block_normal_form(opt_body)
+
+ def test_let2():
+ x = relay.Var("x")
+ y = relay.Var("y")
+ d = relay.const(4.0, 'float32')
+ body = relay.Let(y, x, x)
+ body = relay.Let(x, d, body)
+ body = run_opt_pass(body, transform.InferType())
+ check_eval(body, 4)
+
+ def expected():
+ x = relay.Var("x")
+ y = relay.Var("y")
+ d = relay.const(4.0, 'float32')
+ body = relay.Let(y, x, y)
+ body = relay.Let(x, d, body)
+ return body
+
+ opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+ expected_body = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(opt_body, expected_body)
+ check_basic_block_normal_form(opt_body)
+
+ def test_let3():
+ x = relay.Var("x")
+ y = relay.Var("y")
+ z = relay.Var("z")
+ c = relay.const(3.0, 'float32')
+ d = relay.const(4.0, 'float32')
+ body = relay.Let(z, x + y, x + z)
+ body = relay.Let(x, d, body)
+ body = relay.Let(y, c, body)
+ body = run_opt_pass(body, transform.InferType())
+ opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+ assert tvm.ir.structural_equal(body, opt_body)
+ check_basic_block_normal_form(opt_body)
+
+ test_let1()
+ test_let1_1()
+ test_let2()
+ test_let3()
+
+def test_function():
+ t = relay.TensorType((), 'float32')
+ x = relay.Var("x", t)
+ f = relay.Function([x], x + x)
+ d = relay.const(4.0, 'float32')
+ bblock = run_opt_pass(f, transform.ToBasicBlockNormalForm())
+ assert isinstance(bblock, relay.Function)
+ check_eval(f(d), 8)
+ check_eval(bblock(d), 8)
+ check_basic_block_normal_form(bblock)
+
+def test_gradient_if():
+ x = relay.var("a", shape=(1, 16))
+ y = relay.var("y", shape=(1, 16))
+ cond = relay.var("cond", shape=(), dtype='uint1')
+ net = relay.If(cond, x, x)
+ net = relay.add(x, net)
+ net = relay.Function([cond,x,y], net)
+ mod = tvm.IRModule.from_expr(net)
+ mod = relay.transform.ToBasicBlockNormalForm()(mod)
+ net_grad = relay.transform.gradient(mod["main"], mode='higher_order')
+ mod["main"] = net_grad
+ mod_grad = relay.transform.ToBasicBlockNormalForm()(mod)
+ check_basic_block_normal_form(mod_grad['main'])
+ check_basic_block_normal_form(mod['main'])
+
+def test_if():
+ def if_expr(x):
+ """
+ free_var %x: float32
+ %0 = equal(%x, 2f);
+ if (%0) {
+ %1 = add(%x, 1f);
+ multiply(%1, 2f)
+ } else {
+ multiply(%1, 1f)
+ }
+ """
+ one = relay.const(1, dtype='float32')
+ two = relay.const(2, dtype='float32')
+ v1 = relay.add(x, one)
+ v2 = relay.equal(x, two)
+ true_branch = relay.multiply(v1, two)
+ false_branch = relay.multiply(v1, one)
+ body = relay.If(v2, true_branch, false_branch)
+ return body
+
+ def expected_if_expr(x):
+ """
+ free_var %x: float32
+ let %v1: float32 = add(%x, 1f /* ty=float32 */) /* ty=float32 */;
+ %0 = equal(%x, 2f /* ty=float32 */) /* ty=bool */;
+ if (%0) {
+ multiply(%v1, 2f /* ty=float32 */) /* ty=float32 */
+ } else {
+ multiply(%v1, 1f /* ty=float32 */) /* ty=float32 */
+ }
+ """
+ one = relay.const(1, dtype='float32')
+ two = relay.const(2, dtype='float32')
+ v1 = relay.var('v1')
+ v2 = relay.equal(x, two)
+ true_branch = relay.multiply(v1, two)
+ false_branch = relay.multiply(v1, one)
+ body = relay.If(v2, true_branch, false_branch)
+ body = relay.Let(v1, relay.add(x, one), body)
+ return body
+
+ x = relay.var('x', shape=(), dtype='float32')
+ body = if_expr(x)
+ expected_body = expected_if_expr(x)
+ bblock = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+ expected_bblock = run_opt_pass(expected_body, transform.InferType())
+ assert tvm.ir.structural_equal(bblock, expected_bblock, map_free_vars=True)
+ check_basic_block_normal_form(bblock)
+
+ func = relay.Function([x], body)
+ expected_func = relay.Function([x], expected_body)
+ bblock = run_opt_pass(func, transform.ToBasicBlockNormalForm())
+ expected_bblock = run_opt_pass(expected_func, transform.InferType())
+ assert tvm.ir.structural_equal(bblock, expected_bblock)
+ check_basic_block_normal_form(bblock)
+
+def test_higher_order_return():
+ x = relay.var('x', shape=(1,), dtype='float32')#, a)
+ y = relay.var('y', shape=(1,), dtype='float32')#, a)
+ z = relay.var('z', shape=(1,), dtype='float32')#, a)
+ x2 = relay.add(x, x)
+ func_a = relay.Function([y], relay.add(x2, y)) #, a, [a])
+ func_b = relay.Function([z], relay.add(x2, z)) #, a, [a])
+ body = relay.Tuple([func_a, func_b])
+ body = relay.Function([x], body)
+ """
+ fn (%x: Tensor[(1), float32]) {
+ %1 = fn (%y: Tensor[(1), float32]) {
+ %0 = add(%x, %x);
+ add(%0, %y)
+ };
+ %2 = fn (%z: Tensor[(1), float32]) {
+ add(%0, %z)
+ };
+ (%1, %2)
+ }
+ """
+
+ bblock = run_opt_pass(body, transform.ToBasicBlockNormalForm())
+ check_basic_block_normal_form(bblock)
+
+
+def test_higher_order_nested():
+ x = relay.var('x', dtype='float32', shape=(1,))
+ s = relay.var('s', dtype='float32', shape=(1,))
+ shared = relay.add(s, s)
+ func_true = relay.Function([x], relay.add(x, shared))
+ choice_t = relay.FuncType([], relay.scalar_type('bool'))
+ f = relay.Var('f', choice_t)
+ z = relay.Var('z')
+ body = relay.If(f(), func_true, relay.Function([z], relay.add(z, shared)))
+ top = relay.Function([f, s], body)
+ """
+ fn (%f: fn () -> bool, %s: Tensor[(1), float32]) {
+ %0 = %f();
+ if (%0) {
+ fn (%x: Tensor[(1), float32]) {
+ %1 = add(%s, %s);
+ add(%x, %1)
+ }
+ } else {
+ fn (%z) {
+ add(%z, %1)
+ }
+ }
+ }
+ """
+
+ bblock = run_opt_pass(top, transform.ToBasicBlockNormalForm())
+ check_basic_block_normal_form(bblock)
+
+if __name__ == '__main__':
+ pytest.main([__file__])