* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on `a`.
*
- * As another example, `let a = 1 in a` will be optimized into 1.
+ * As another example, `let a = 1 in a` will be optimized into 1,
+ * if the flag is turned on.
*
* \param e the expression to optimize.
+ * \param inline_once whether or not to inline bindings used exactly once.
*
* \return the optimized expression.
*/
-TVM_DLL Expr DeadCodeElimination(const Expr& e);
+TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
/*!
* \brief Fold constant expressions.
* It has two benefits: removing runtime overhead and allowing more optimization (typically fusion).
* As a side effect, code size will explode.
*
- * \param e the expression,
+ * \param e the expression
+ * \param mod the module
*
* \return the optimized expression.
*/
-TVM_DLL Expr PartialEval(const Expr& e);
+TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
/*!
* \brief Bind the free variables to a Relay expression.
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
+ * \param inline_once whether or not to inline bindings used exactly once.
+ *
* \return the pass.
*/
-TVM_DLL Pass DeadCodeElimination();
+TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
/*!
* \brief Fold constant expressions.
Parameters
----------
- expr: tvm.relay.Expr
+ expr : tvm.relay.Expr
The input expression
Returns
Parameters
----------
- expr: tvm.relay.Expr
+ expr : tvm.relay.Expr
The input expression
Returns
Parameters
----------
- expr: tvm.relay.Expr
+ expr : tvm.relay.Expr
The input expression
Returns
Parameters
----------
- expr: tvm.relay.Expr
+ expr : tvm.relay.Expr
The input expression
Returns
Parameters
----------
- expr: Union[tvm.relay.Expr,tvm.relay.Type]
+ expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
- mod: tvm.relay.Module, optional
+
+ mod : Optional[tvm.relay.Module]
The global module
Returns
Parameters
----------
- expr: Union[tvm.relay.Expr,tvm.relay.Type]
+ expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
- mod: tvm.relay.Module, optional
+
+ mod : Optional[tvm.relay.Module]
The global module
Returns
Parameters
----------
- expr: Union[tvm.relay.Expr,tvm.relay.Type]
+ expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
- mod: tvm.relay.Module, optional
+ mod : Optional[tvm.relay.Module]
The global module
Returns
Parameters
----------
- e: tvm.relay.Expr
+ expr : tvm.relay.Expr
The input Expression
Returns
-------
- result: tvm.relay.Expr
+ result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with some simplification
"""
Parameters
----------
- e: tvm.relay.Expr
+ expr : tvm.relay.Expr
The input Expression
Returns
-------
- result: tvm.relay.Expr
+ result : tvm.relay.Expr
An expression without bias_add
"""
return _ir_pass.canonicalize_ops(expr)
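
For instance (a minimal sketch; `relay.var` and `relay.nn.bias_add` are the standard Relay helpers, not part of this patch), `bias_add` is canonicalized into an explicit `expand_dims` plus a broadcast add:

from tvm import relay

x = relay.var("x", shape=(1, 16, 7, 7))
b = relay.var("b", shape=(16,))
# infer_type first: the rewrite needs the checked type to pick num_newaxis
y = relay.ir_pass.infer_type(relay.nn.bias_add(x, b))
y = relay.ir_pass.canonicalize_ops(y)
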
-def dead_code_elimination(expr):
+def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which does not effect the program result (dead code).
Parameters
----------
- e: tvm.relay.Expr
+ expr : tvm.relay.Expr
The input Expression
+ inline_once : Optional[Bool]
+ Whether to inline bindings that occur only once.
Returns
-------
- result: tvm.relay.Expr
+ result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
- return _ir_pass.dead_code_elimination(expr)
+ return _ir_pass.dead_code_elimination(expr, inline_once)
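
A small usage sketch of the new flag (variable names here are illustrative):

from tvm import relay

x = relay.Var("x")
e = relay.Let(x, relay.const(1), x)
# left as-is by default; inlined to `1` when inline_once=True
print(relay.ir_pass.dead_code_elimination(e, inline_once=True))
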
def alpha_equal(lhs, rhs):
Parameters
----------
- lhs: tvm.relay.Expr
+ lhs : tvm.relay.Expr
One of the input Expressions.
- rhs: tvm.relay.Expr
+ rhs : tvm.relay.Expr
One of the input Expressions.
Returns
-------
- result: bool
+ result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))
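
For example (a minimal sketch): two identity functions that differ only in the name of the bound variable are alpha equal:

from tvm import relay

x = relay.Var("x")
y = relay.Var("y")
assert relay.ir_pass.alpha_equal(relay.Function([x], x), relay.Function([y], y))
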
Parameters
----------
- lhs: tvm.relay.Expr
+ lhs : tvm.relay.Expr
One of the input Expressions.
- rhs: tvm.relay.Expr
+ rhs : tvm.relay.Expr
One of the input Expressions.
Returns
-------
- result: bool
+ result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
Parameters
----------
- expr: tvm.relay.Expr or tvm.relay.Type
+ expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash.
Returns
-------
- result: int
+ result : int
The hash value
"""
if isinstance(value, Expr):
expr : tvm.relay.Expr
The input expression.
- mod: Optional[tvm.relay.Module]
+ mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
- expr: tvm.relay.Expr
+ result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)
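
A rough illustration (a sketch; the let-variable names below are chosen by the pass): ANF names every intermediate value with a let:

from tvm import relay

e = relay.multiply(relay.add(relay.const(1), relay.const(2)), relay.const(3))
# roughly: let %x0 = add(1, 2); let %x1 = multiply(%x0, 3); %x1
print(relay.ir_pass.to_a_normal_form(e))
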
The input expression
Returns
-------
- expr : tvm.relay.Expr
+ result : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)
Returns
-------
- ret : int64
+ result : int64
The number of MACs (multiply-accumulate operations) in the model
"""
return _ir_pass.GetTotalMacNumber(expr)
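
For instance (a sketch, assuming the snake_case wrapper `relay.ir_pass.get_total_mac_number` that fronts GetTotalMacNumber, and a type-inferred input):

from tvm import relay

data = relay.var("data", shape=(1, 3, 224, 224))
weight = relay.var("weight", shape=(64, 3, 7, 7))
conv = relay.nn.conv2d(data, weight, kernel_size=(7, 7), channels=64, padding=(3, 3))
print(relay.ir_pass.get_total_mac_number(relay.ir_pass.infer_type(conv)))
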
expr : tvm.relay.Expr
The input expression.
- fskip: function
+ fskip : function
The callback function that decides whether an expression should be skipped.
Returns
-------
- expr : tvm.relay.Expr
+ result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)
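
A usage sketch for `fskip`, exempting dropout calls from CSE (mirrors the style of the pass's tests; names are illustrative):

from tvm import relay

def fskip(e):
    # skip any call to nn.dropout so its two occurrences are not merged
    return isinstance(e, relay.expr.Call) and e.op.name == "nn.dropout"

x = relay.var("x", shape=(1, 16))
y = relay.nn.dropout(x)
out = relay.ir_pass.eliminate_common_subexpr(relay.add(y, y), fskip)
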
-def partial_evaluate(expr):
+def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
expr : tvm.relay.Expr
The input expression.
+ mod : Optional[tvm.relay.Module]
+ The global module
+
Returns
-------
- expr : tvm.relay.Expr
+ result : tvm.relay.Expr
The output expression.
"""
- return _ir_pass.partial_evaluate(expr)
+ return _ir_pass.partial_evaluate(expr, mod)
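
With a module available, calls to global and Prelude functions can now be unfolded as well (a sketch mirroring test_map below):

from tvm import relay
from tvm.relay.prelude import Prelude

mod = relay.Module()
p = Prelude(mod)
f = relay.Var("f")
e = p.map(f, p.cons(relay.const(1), p.nil()))
print(relay.ir_pass.partial_evaluate(e, mod=mod))
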
def unmatched_cases(match, mod=None):
"""
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
- p->stream << "CallNode(" << node->op << ", " << node->args << ", "
- << node->attrs << ", " << node->type_args << ")";
+ p->stream << "CallNode(" << node->op << ", " << node->args << ", "
+ << node->attrs << ", " << node->type_args << ")";
});
Let LetNode::make(Var var, Expr value, Expr body) {
TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
- return temp->Realize();
+ return temp->Realize();
});
} // namespace relay
// calculate the dependency graph from expression
class CalcDep : private ExprVisitor {
public:
- static Expr Eliminate(const Expr& e) {
+ static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd;
cd.Calculate(e);
- Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
+ Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
return el(e);
}
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
+ bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
- const VarSet& letrec_set) :
- expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
+ const VarSet& letrec_set,
+ bool inline_once) :
+ expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
friend CalcDep;
bool HasLet(const Var& v) {
- // TODO(@jroesch): MK fix me
- return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
+ switch (use_map_[v]) {
+ case 0:
+ return false;
+ case 1:
+ return letrec_set_.count(v) > 0 || !inline_once_;
+ default:
+ return true;
+ }
}
Expr VisitExpr_(const VarNode* op) final {
};
};
-Expr DeadCodeElimination(const Expr& e) {
- return CalcDep::Eliminate(e);
+Expr DeadCodeElimination(const Expr& e, bool inline_once) {
+ return CalcDep::Eliminate(e, inline_once);
}
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
namespace transform {
-Pass DeadCodeElimination() {
+Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
- return Downcast<Function>(DeadCodeElimination(f));
+ return Downcast<Function>(DeadCodeElimination(f, inline_once));
};
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
}
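
Assuming the Python binding in relay.transform mirrors this C++ signature (hypothetical here, since only the C++ side appears in this patch), the pass could be scheduled as:

from tvm import relay

seq = relay.transform.Sequential([relay.transform.DeadCodeElimination(inline_once=True)])
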
*
* The partial evaluator makes several assumptions, so there is room for improvement:
*
- * 0: The partial evaluator treats global variables as opaque.
- * Doing PartialEval on a module level will solve this.
- *
- * 1: The partial evaluator assume all functions as terminating.
- * We need to has a max_expand parameter that shrink on every compile time evaluation,
- * to make sure PE does not infinite loop.
- * Additionally, we might add a termination analysis pass that lift this requirement
- * for function that analysis found terminating.
- *
- * 2: Every time an unknown effect happened, we clear the whole store.
+ * 0: Every time an unknown effect happened, we clear the whole store.
* It is too conservative: if a local reference is created (and does not get passed outside),
* an unknown global function call/global reference write cannot modify it.
* We can pair PE with escape analysis/alias analysis.
*
- * 3: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
+ * 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
*
- * 4: When doing pattern matching, we can simplify the match even for dynamic case.
+ * 2: When doing pattern matching, we can simplify the match even for dynamic case.
* Right now it is all or nothing: either a complete match, or the original dynamic code.
* Instead, we can get a match tree, pair it with the data and evaluate it to a normal form.
* We then can reify the result.
*
- * 5: Every time a function is called, it's code will get expanded and partially evaluated.
+ * 3: Every time a function is called, its code will get expanded and partially evaluated.
* We can do a binding time analysis to cache the result and avoid re-partial evaluation.
*
* These assumptions do not affect the correctness of the algorithm, however.
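
As a concrete instance of point 0 (a hypothetical sketch; `f` is deliberately a free, statically unknown function):

from tvm import relay

r = relay.Var("r")
f = relay.Var("f")
u = relay.Var("u")
body = relay.Let(u, relay.Call(f, []), relay.RefRead(r))
prog = relay.Let(r, relay.RefCreate(relay.const(1)), body)
# the call to f clears the whole store, so RefRead(r) is not reduced to 1,
# even though r never escapes and f cannot reach it
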
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
+#include "../ir/type_functor.h"
#include "pass_util.h"
#include "let_list.h"
}
};
+Expr PostProcess(const Expr&);
+
/*! \brief The base container type of Relay values. */
class StaticNode : public RelayNode {
public:
using ContainerType = StaticNode;
};
+using Time = size_t;
+
struct PStaticNode : Node {
+ static Time time() {
+ static Time time_ = 0;
+ Time ret = time_;
+ time_++;
+ return ret;
+ }
Static pstatic; // may be null
Expr dynamic;
- PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { }
+ Time created_time;
+ PStaticNode(const Static& pstatic, const Expr& dynamic) :
+ pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
};
};
PStatic HasStatic(const Static& stat, const Expr& dynamic) {
+ CHECK(stat.defined());
return PStatic(make_node<PStaticNode>(stat, dynamic));
}
return CreateInterpreter(Module(nullptr), CPUContext(), target);
}
+bool IsAtomic(const Expr& e) {
+ return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
+}
+
+using FuncId = int;
+
+/*!
+ * \brief Annotate a function with a FuncId.
+ */
+struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
+ FuncId fid;
+
+ TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") {
+ TVM_ATTR_FIELD(fid)
+ .describe("The FuncId that an function is annotated with.")
+ .set_default(-1);
+ }
+};
+
+TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
+
+Op WithFuncIdOp() {
+ static const Op& op = Op::Get("annotation.with_funcid");
+ return op;
+}
+
+Expr MkWithFuncId(const Expr& expr, FuncId fid) {
+ auto attrs = make_node<WithFuncIdAttrs>();
+ attrs->fid = fid;
+ return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {});
+}
+
+RELAY_REGISTER_OP("annotation.with_funcid")
+.describe(R"code(Annotate a function with a funcid.)code"
+TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("func", "Function", "The input data.");
+
+Expr StripWithFuncId(const Expr& e);
+
+Expr DeDup(const Expr& e);
+
+Function AsFunc(const Expr& e) {
+ if (e.as<FunctionNode>()) {
+ return Downcast<Function>(e);
+ } else if (const CallNode* c = e.as<CallNode>()) {
+ CHECK(c->op.same_as(WithFuncIdOp()));
+ CHECK_EQ(c->args.size(), 1);
+ return AsFunc(c->args[0]);
+ } else {
+ LOG(FATAL) << "Unknown case";
+ throw;
+ }
+}
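
A rough Python analogue of this FuncId bookkeeping (a sketch of the idea only, not of the C++ API): every function node gets a stable integer id, so copies produced later by DeDup/TypeSubst can be traced back to the same source function:

func_map = {}  # function node -> FuncId, mirrors PartialEvaluator::func_map_

def register_func(f):
    # hypothetical helper corresponding to InitializeFuncId/RegisterFuncId below
    return func_map.setdefault(f, len(func_map))
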
+
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
public:
- PartialEvaluator(const tvm::Array<Var>& free_vars) {
+ PartialEvaluator(const tvm::Array<Var>& free_vars,
+ const Module& mod) :
+ mod_(mod) {
for (const Var& v : free_vars) {
env_.Insert(v, NoStatic(v));
}
}
+ PStatic VisitExpr(const Expr& e, LetList* ll) final {
+ PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
+ CHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
+ return ret;
+ }
+
PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
}
}
PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
- return NoStatic(GetRef<Expr>(op));
+ GlobalVar gv = GetRef<GlobalVar>(op);
+ if (gv_map_.count(gv) == 0) {
+ if (mod_.defined()) {
+ Function func = mod_->Lookup(gv);
+ InitializeFuncId(func);
+ Func f = VisitFuncStatic(func, gv);
+ gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
+ func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
+ mod_->Update(gv, func);
+ } else {
+ gv_map_.insert({gv, NoStatic(gv)});
+ }
+ }
+ return gv_map_.at(gv);
}
PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
}
PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
+ if (op->op.same_as(WithFuncIdOp())) {
+ CHECK_EQ(op->args.size(), 1);
+ return VisitExpr(op->args[0], ll);
+ }
PStatic f = VisitExpr(op->op, ll);
std::vector<PStatic> x;
tvm::Array<Expr> x_dyn;
}
}
- PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
- Function func = GetRef<Function>(op);
+ struct TimeFrame {
+ PartialEvaluator* pe_;
+ FuncId fid_;
+ std::vector<Time> old_time;
+ bool has_old_time;
+ TimeFrame(PartialEvaluator* pe,
+ FuncId fid,
+ const std::vector<Time>& args_time) : pe_(pe), fid_(fid) {
+ has_old_time = pe_->time_map_.count(fid_) > 0;
+ old_time = pe_->time_map_[fid_];
+ pe_->time_map_[fid_] = args_time;
+ }
+ ~TimeFrame() {
+ if (has_old_time) {
+ pe_->time_map_[fid_] = old_time;
+ } else {
+ pe_->time_map_.erase(fid_);
+ }
+ }
+ };
+
+ Func VisitFuncStatic(const Function& func, const Expr& var) {
+ CHECK(IsAtomic(var));
if (func->IsPrimitive()) {
- return HasStatic(MkSFunc(ConstEvaluateFunc(func, ll)), func);
+ return ConstEvaluateFunc(func);
}
std::vector<std::pair<Var, PStatic> > free_vars;
- for (const auto& v : FreeVars(GetRef<Expr>(op))) {
+ for (const auto& v : FreeVars(func)) {
free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
}
- Func f = [=](const std::vector<PStatic>& pv,
- const Attrs& attrs,
- const tvm::Array<Type>& type_args,
- LetList* ll) {
+ return [=](const std::vector<PStatic>& pv,
+ const Attrs& attrs,
+ const tvm::Array<Type>& type_args,
+ LetList* ll) {
return env_.Extend<PStatic>([&]() {
CHECK_EQ(pv.size(), func->params.size());
for (size_t i = 0; i < pv.size(); ++i) {
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], Type());
}
- return VisitExpr(TypeSubst(func->body, subst), ll);
+ std::vector<Time> args_time;
+ for (const auto& v : pv) {
+ args_time.push_back(v->created_time);
+ }
+ CHECK_GT(func_map_.count(func), 0);
+ FuncId fid = func_map_.at(func);
+ auto recurse = [&]() {
+ TimeFrame tf(this, fid, args_time);
+ return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
+ };
+ if (time_map_.count(fid) == 0) {
+ return recurse();
+ } else {
+ /* We check to see that at least one argument decreases
+ * with respect to all previous invocations.
+ * The depth of the recursion is bounded by
+ * the sum of the times of all arguments at the first call.
+ */
+ bool can_recurse = false;
+ std::vector<Time>& min_time = time_map_.at(fid);
+ CHECK_EQ(args_time.size(), min_time.size());
+ for (size_t i = 0; i < args_time.size(); ++i) {
+ if (args_time[i] < min_time[i]) {
+ can_recurse = true;
+ }
+ args_time[i] = std::min(args_time[i], min_time[i]);
+ }
+ if (can_recurse) {
+ return recurse();
+ } else {
+ std::vector<Expr> dyn;
+ for (const auto& v : pv) {
+ dyn.push_back(v->dynamic);
+ }
+ return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args)));
+ }
+ }
});
};
- Expr dyn = store_.Extend<Expr>([&]() {
+ }
+
+
+ Expr VisitFuncDynamic(const Function& func, const Func& f) {
+ return store_.Extend<Expr>([&]() {
store_.Invalidate();
return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
std::vector<PStatic> pv;
return f(pv, Attrs(), type_args, ll)->dynamic;
}), func->ret_type, func->type_params, func->attrs);
});
- return HasStatic(MkSFunc(f), ll->Push(dyn));
+ }
+
+ PStatic VisitFunc(const Function& func, LetList* ll) {
+ Var v = VarNode::make("x", Type());
+ Func f = VisitFuncStatic(func, v);
+ Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
+ // TODO(@M.K.): it seems we reduce Landin's knot into letrec here.
+ // restore letrec support across the whole of Relay.
+ return HasStatic(MkSFunc(f),
+ ll->Push(v, VisitFuncDynamic(u_func, f)));
+ }
+
+ PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
+ return VisitFunc(GetRef<Function>(op), ll);
}
Expr Reflect(const PStatic& st) {
return Reify(executor_(fused_infered), ll);
}
- Func ConstEvaluateFunc(const Expr& expr, LetList* ll) {
+ Func ConstEvaluateFunc(const Expr& expr) {
+ CHECK_EQ(FreeVars(expr).size(), 0);
return [=](const std::vector<PStatic>& pv,
const Attrs& attrs,
const tvm::Array<Type>& type_args,
for (const PStatic& ps : pv) {
ns_args.push_back(ps->dynamic);
}
- PStatic ns = NoStatic(CallNode::make(expr, ns_args, attrs, type_args));
+ PStatic ns = NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
if (StatefulOp(expr)) {
return ns;
}
}
PStatic VisitExpr_(const OpNode* op, LetList* ll) final {
- return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op), ll)), GetRef<Expr>(op));
+ return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op))), GetRef<Expr>(op));
}
PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
CHECK_NE(op->constructor->tag, -1);
CHECK_NE(scn->constructor->tag, -1);
if (op->constructor->tag == scn->constructor->tag) {
- // todo(M.K.): should use ptr equality but it is broken
CHECK_EQ(op->patterns.size(), scn->fields.size());
MatchStatus current_match_status = MatchStatus::Match;
for (size_t i = 0; i < op->patterns.size(); ++i) {
}
}
+ void InitializeFuncId(const Expr& e) {
+ struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
+ PartialEvaluator* pe;
+ explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
+
+ void VisitExpr_(const FunctionNode* op) final {
+ Function f = GetRef<Function>(op);
+ CHECK_EQ(pe->func_map_.count(f), 0);
+ pe->func_map_.insert({f, pe->func_map_.size()});
+ VisitExpr(f->body);
+ }
+
+ void VisitPattern(const Pattern& p) final {
+ PatternVisitor::VisitPattern(p);
+ }
+ };
+ InitializeFuncIdVisitor(this).VisitExpr(e);
+ }
+
+ Expr RegisterFuncId(const Expr& e) {
+ struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor {
+ PartialEvaluator* pe;
+ explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
+
+ void VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(WithFuncIdOp())) {
+ CHECK_EQ(op->args.size(), 1);
+ CHECK(op->attrs.defined());
+ CHECK(op->attrs.as<WithFuncIdAttrs>());
+ Function f = AsFunc(op->args[0]);
+ FuncId fid = op->attrs.as<WithFuncIdAttrs>()->fid;
+ if (pe->func_map_.count(f) != 0) {
+ CHECK_EQ(pe->func_map_.at(f), fid);
+ }
+ pe->func_map_.insert({f, fid});
+ }
+ ExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const FunctionNode* op) final {
+ Function f = GetRef<Function>(op);
+ CHECK_GT(pe->func_map_.count(f), 0);
+ ExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitPattern(const Pattern& p) final {
+ PatternVisitor::VisitPattern(p);
+ }
+ };
+ RegisterFuncIdVisitor(this).VisitExpr(e);
+ return e;
+ }
+
+ Expr AnnotateFuncId(const Expr& e) {
+ struct AnnotateFuncIdMutator : ExprMutator, PatternMutator {
+ PartialEvaluator* pe;
+ explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { }
+
+ Expr VisitExpr_(const FunctionNode* op) final {
+ Function f = GetRef<Function>(op);
+ CHECK_GT(pe->func_map_.count(f), 0);
+ return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f));
+ }
+
+ Pattern VisitPattern(const Pattern& p) final {
+ return PatternMutator::VisitPattern(p);
+ }
+
+ Var VisitVar(const Var& v) final {
+ return v;
+ }
+ };
+ return AnnotateFuncIdMutator(this).VisitExpr(e);
+ }
+
private:
Environment env_;
+ Module mod_;
+ std::unordered_map<GlobalVar, PStatic, NodeHash, NodeEqual> gv_map_;
+ /*! Termination checking is done as follows:
+ * We have finitely many FunctionIds.
+ * Each FunctionId maps to a class of semantically equivalent functions (ignoring type),
+ * as both TypeSubst and DeDup create semantically equivalent functions.
+ * We partially map each FunctionId to a std::vector<Time>,
+ * denoting the minimal TimeFrame of each argument of the function.
+ * Every time we try to inline a Function,
+ * we make sure it either does not have a vector<Time>, which means this is the initial call,
+ * or some argument has a lesser time, which means an earlier-created argument is being passed in.
+ * In either case, we remap the mapping to the minimal vector<Time> across all previous invocations
+ * when we PE inside the Function body.
+ * Termination is guaranteed because the creation time of at least one argument decreases on every call.
+ */
+ std::unordered_map<Function, FuncId, NodeHash, NodeEqual> func_map_;
+ std::unordered_map<FuncId, std::vector<Time> > time_map_;
Store store_;
DLContext context_ = CPUContext();
FInterpreter executor_ = CPUInterpreter();
};
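
The termination rule from the comment above, as a standalone Python sketch (names are illustrative): recursion is permitted only when some argument is strictly older than in every previous call, and the recorded times only shrink:

def can_recurse(args_time, min_time):
    # any strictly smaller creation time permits another unfolding
    ok = any(a < m for a, m in zip(args_time, min_time))
    for i, (a, m) in enumerate(zip(args_time, min_time)):
        min_time[i] = min(a, m)  # keep the pointwise minimum across invocations
    return ok
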
-Var DeDupVar(const Var& v) {
- return VarNode::make(v->name_hint(), v->type_annotation);
-}
-
-TypeVar DeDupTypeVar(const TypeVar& tv) {
- return TypeVarNode::make(tv->var->name_hint, tv->kind);
-}
-
/*! \brief Use a fresh Id for every Var to make the result well-formed. */
Expr DeDup(const Expr& e) {
- class DeDupMutator : public ExprMutator, public PatternMutator {
+ class DeDupMutator : public TypeMutator,
+ public ExprMutator,
+ public PatternMutator {
public:
+ TypeVar Fresh(const TypeVar& tv) {
+ TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
+ type_rename_[tv] = ret;
+ return ret;
+ }
+
Var Fresh(const Var& v) {
- Var ret = DeDupVar(v);
+ Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
rename_[v] = ret;
return ret;
}
}
Expr VisitExpr_(const LetNode* op) final {
- return LetNode::make(Fresh(op->var), VisitExpr(op->value), VisitExpr(op->body));
+ Var v = Fresh(op->var);
+ return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
+ }
+
+ Type VisitType(const Type& t) final {
+ return t.defined() ? TypeMutator::VisitType(t) : t;
}
Expr VisitExpr_(const FunctionNode* op) final {
+ tvm::Array<TypeVar> type_params;
+ for (const TypeVar& type_param : op->type_params) {
+ type_params.push_back(Fresh(type_param));
+ }
tvm::Array<Var> params;
for (const Var& param : op->params) {
params.push_back(Fresh(param));
}
return FunctionNode::make(params,
VisitExpr(op->body),
- op->ret_type,
- op->type_params,
+ VisitType(op->ret_type),
+ type_params,
op->attrs);
}
return PatternMutator::VisitPattern(p);
}
+ Clause VisitClause(const Clause& c) final {
+ Pattern pat = VisitPattern(c->lhs);
+ return ClauseNode::make(pat, VisitExpr(c->rhs));
+ }
+
+ Type VisitType_(const TypeVarNode* op) final {
+ TypeVar v = GetRef<TypeVar>(op);
+ return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
+ }
+
Var VisitVar(const Var& v) final {
return Fresh(v);
}
private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
+ std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
};
- return DeDupMutator().VisitExpr(e);
+
+ Expr ret = DeDupMutator().VisitExpr(e);
+ CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
+ return ret;
}
/*! \brief Remap multiple Var sharing the same Id into the same Var. */
return RemapMutator().VisitExpr(e);
}
-Expr PartialEval(const Expr& e) {
+Expr StripWithFuncId(const Expr& e) {
+ struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
+ Expr VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(WithFuncIdOp())) {
+ CHECK_EQ(op->args.size(), 1);
+ return VisitExpr(op->args[0]);
+ } else {
+ return ExprMutator::VisitExpr_(op);
+ }
+ }
+
+ Pattern VisitPattern(const Pattern& p) final {
+ return PatternMutator::VisitPattern(p);
+ }
+
+ Var VisitVar(const Var& v) final {
+ return v;
+ }
+ };
+ return StripWithFuncIdMutator().VisitExpr(e);
+}
+
+Expr PostProcess(const Expr& e) {
+ return StripWithFuncId(DeDup(Remap(e)));
+}
+
+Expr PartialEval(const Expr& e, const Module& m) {
return TransformF([&](const Expr& e) {
return LetList::With([&](LetList* ll) {
- PartialEvaluator pe(FreeVars(e));
- return Remap(DeDup(pe.VisitExpr(e, ll)->dynamic));
+ PartialEvaluator pe(FreeVars(e), m);
+ pe.InitializeFuncId(e);
+ return PostProcess(pe.VisitExpr(e, ll)->dynamic);
});
}, e);
}
Pass PartialEval() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
- return Downcast<Function>(PartialEval(f));
+ return Downcast<Function>(PartialEval(f, m));
};
return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
}
import numpy as np
import tvm
from tvm import relay
-from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination
-from tvm.relay.ir_pass import gradient, alpha_equal, infer_type
+from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination
+from tvm.relay.ir_pass import gradient
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude
from tvm.relay import create_executor
-
from nose.tools import nottest
+from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
+from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
+from tvm.relay import GlobalVar, Call, Type
+from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
ctx = tvm.context("llvm", 0)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
-def dcpe(expr):
- return dead_code_elimination(partial_evaluate(expr))
+def dcpe(expr, mod=None):
+ return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True)
def test_tuple():
- t = relay.TypeVar("t")
- x = relay.Var("x", t)
- body = relay.TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
- f = relay.Function([x], body, None, [t])
+ t = TypeVar("t")
+ x = Var("x", t)
+ body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
+ f = Function([x], body, None, [t])
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
-@nottest
def test_const_inline():
- # TODO(MK): fix me
- d = relay.Var("d")
- double = relay.Function([d], d + d)
- orig = double(relay.const(4.0))
- assert alpha_equal(dcpe(double(relay.const(4.0))), relay.const(8.0))
+ d = Var("d")
+ double = Function([d], d + d)
+ orig = double(const(4.0))
+ assert alpha_equal(dcpe(orig), const(8.0))
def test_ref():
r = relay.Var("r")
x = relay.Var("x")
body = relay.RefRead(r)
- body = relay.Let(x, relay.RefWrite(r, relay.RefRead(r) * relay.RefRead(r)), body)
- body = relay.Let(r, relay.RefCreate(d), body)
- square = relay.Function([d], body)
- assert alpha_equal(dcpe(square), relay.Function([d], d * d))
+ body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body)
+ body = Let(r, RefCreate(d), body)
+ square = Function([d], body)
+ assert alpha_equal(dcpe(square), Function([d], d * d))
+
+
+def test_empty_ad():
+ shape = (10, 10)
+ dtype = "float32"
+ t = TensorType(shape, dtype)
+ d = Var("d", t)
+ f = Function([d], d)
+ g = dcpe(gradient(f))
+ expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
+ assert alpha_equal(g, expected)
-@nottest
def test_ad():
- # TODO(MK): fix me
shape = (10, 10)
dtype = "float32"
- t = relay.TensorType(shape, dtype)
- d = relay.Var("d", t)
- f = relay.Function([d], d * d)
+ t = TensorType(shape, dtype)
+ d = Var("d", t)
+ f = Function([d], d * d)
g = dcpe(gradient(f))
m = d * d
- o = relay.op.ones_like(m)
- grad = relay.op.zeros_like(d) + relay.op.collapse_sum_like(o * d, d) + relay.op.collapse_sum_like(o * d, d)
- expected = relay.Function([d], relay.Tuple([m, relay.Tuple([grad])]))
+ x = relay.Var("x")
+ o = op.ones_like(x)
+ x1 = relay.Var("x1")
+ grad = op.zeros_like(d) + op.collapse_sum_like(x1 * d, d) + op.collapse_sum_like(x1 * d, d)
+ body = Tuple([x, Tuple([grad])])
+ body = relay.Let(x1, o, body)
+ expected = Function([d], relay.Let(x, m, body))
assert alpha_equal(g, expected)
def test_if_ref():
shape = ()
dtype = "bool"
- t = relay.TensorType(shape, dtype)
- d = relay.Var("d", t)
- r = relay.Var("r")
- update = relay.Function([], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)))
- u = relay.Var("u")
- body = relay.If(d, u(), u())
- eff = relay.Var("eff")
- body = relay.Let(eff, body, relay.RefRead(r))
- f = relay.Function([d], relay.Let(r, relay.RefCreate(relay.const(1)), relay.Let(u, update, body)))
+ t = TensorType(shape, dtype)
+ d = Var("d", t)
+ r = Var("r")
+ update = Function([], RefWrite(r, RefRead(r) + RefRead(r)))
+ u = Var("u")
+ body = If(d, u(), u())
+ eff = Var("eff")
+ body = Let(eff, body, RefRead(r))
+ f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
f = infer_type(f)
pe_f = infer_type(partial_evaluate(f))
ex = create_executor()
- f_res = ex.evaluate(f)(relay.const(True))
- pe_f_res = ex.evaluate(pe_f)(relay.const(True))
+ f_res = ex.evaluate(f)(const(True))
+ pe_f_res = ex.evaluate(pe_f)(const(True))
np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
def test_function_invalidate():
shape = ()
dtype = "bool"
- t = relay.TensorType(shape, dtype)
- d = relay.Var("d", t)
- r = relay.Var("r")
- fetch = relay.Function([], relay.RefRead(r))
- fet = relay.Var("fetch")
- fet_obscured = relay.Var("fetch_obscured")
- u = relay.Var("u")
- body = relay.If(d, fet_obscured(), fet_obscured())
- body = relay.Let(u, relay.RefWrite(r, relay.const(1)), body)
- body = relay.Let(fet_obscured, relay.If(d, fet, fet), body)
- body = relay.Let(fet, fetch, body)
- body = relay.Let(r, relay.RefCreate(relay.const(0)), body)
- f = relay.Function([d], body)
+ t = TensorType(shape, dtype)
+ d = Var("d", t)
+ r = Var("r")
+ fetch = Function([], RefRead(r))
+ fet = Var("fetch")
+ fet_obscured = Var("fetch_obscured")
+ u = Var("u")
+ body = If(d, fet_obscured(), fet_obscured())
+ body = Let(u, RefWrite(r, const(1)), body)
+ body = Let(fet_obscured, If(d, fet, fet), body)
+ body = Let(fet, fetch, body)
+ body = Let(r, RefCreate(const(0)), body)
+ f = Function([d], body)
f = infer_type(f)
pe_f = infer_type(partial_evaluate(f))
ex = create_executor()
- f_res = ex.evaluate(f)(relay.const(True))
- pe_f_res = ex.evaluate(pe_f)(relay.const(True))
+ f_res = ex.evaluate(f)(const(True))
+ pe_f_res = ex.evaluate(pe_f)(const(True))
np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
def test_head_cons():
- mod = relay.Module()
+ mod = Module()
p = Prelude(mod)
def hd_impl():
- a = relay.TypeVar("a")
- x = relay.Var("x", p.l(a))
- y = relay.Var("y")
- z = relay.Var("z")
- cons_case = relay.Clause(relay.PatternConstructor(p.cons,
- [relay.PatternVar(y),
- relay.PatternVar(z)]),
- y)
- return relay.Function([x], relay.Match(x, [cons_case]), a, [a])
- t = relay.TypeVar("t")
- x = relay.Var("x", t)
- hd = relay.Var("hd")
- body = relay.Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
- f = relay.Function([x], body, None, [t])
+ a = TypeVar("a")
+ x = Var("x", p.l(a))
+ y = Var("y")
+ z = Var("z")
+ cons_case = Clause(PatternConstructor(p.cons,
+ [PatternVar(y),
+ PatternVar(z)]),
+ y)
+ y = Var("y")
+ z = Var("z")
+ return Function([x], Match(x, [cons_case]), a, [a])
+ t = TypeVar("t")
+ x = Var("x", t)
+ hd = Var("hd")
+ body = Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
+ f = Function([x], body, None, [t])
f = infer_type(f, mod=mod)
res = dcpe(f)
- assert alpha_equal(res, relay.Function([x], x, t, [t]))
+ assert alpha_equal(res, Function([x], x, t, [t]))
+
+
+def test_map():
+ mod = Module()
+ p = Prelude(mod)
+ f = Var("f")
+ orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
+ expected = p.cons(f(const(1)), p.cons(f(const(2)), p.cons(f(const(3)), p.nil())))
+ assert alpha_equal(dcpe(orig, mod=mod), expected)
+
+
+def test_loop():
+ mod = Module()
+ t = TypeVar("t")
+ x = Var("x", t)
+ loop = GlobalVar("loop")
+ mod[loop] = Function([x], loop(x), t, [t])
+ res = dcpe(loop(const(1)), mod=mod)
+ expected = Call(loop, [const(1)], None, [None])
+ assert alpha_equal(res, expected)
+
+
+def test_swap_loop():
+ mod = Module()
+ p = Prelude(mod)
+ add_nat_definitions(p)
+ nat = p.nat()
+ x = Var("x", nat)
+ y = Var("y", nat)
+ loop = GlobalVar("loop")
+ mod[loop] = Function([x, y], loop(y, x), nat)
+ prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
+ res = dcpe(prog, mod=mod)
+ assert alpha_equal(prog, res)
+
+
+def test_abs_diff():
+ # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
+ mod = Module()
+ p = Prelude(mod)
+ add_nat_definitions(p)
+ nat = p.nat()
+ x = Var("x", nat)
+ y = Var("y", nat)
+ xp = Var("x'", nat)
+ yp = Var("y'", nat)
+ diff = GlobalVar("diff")
+ y_z_case = Clause(PatternConstructor(p.z, []), x)
+ y_s_case = Clause(PatternConstructor(p.s, [PatternVar(yp)]), diff(yp, xp))
+ x_z_case = Clause(PatternConstructor(p.z, []), y)
+ x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
+ mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
+ orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
+ res = dcpe(orig, mod=mod)
+ assert alpha_equal(res, make_nat_expr(p, 4))
+
+
+def test_match_nat_id():
+ mod = Module()
+ p = Prelude(mod)
+ add_nat_definitions(p)
+ nat = p.nat()
+ x = Var("x", nat)
+ y = Var("y", nat)
+ nat_id = GlobalVar("nat_id")
+ z_case = Clause(PatternConstructor(p.z, []), p.z())
+ s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y))
+ mod[nat_id] = Function([x], Match(x, [z_case, s_case]))
+ orig = nat_id(make_nat_expr(p, 3))
+ res = dcpe(orig, mod=mod)
+ assert alpha_equal(res, make_nat_expr(p, 3))
+
+
+def test_nat_id():
+ mod = Module()
+ p = Prelude(mod)
+ add_nat_definitions(p)
+ nat = p.nat()
+ x = Var("x", nat)
+ y = Var("y", nat)
+ nat_id = GlobalVar("nat_id")
+ mod[nat_id] = Function([x], x)
+ orig = nat_id(make_nat_expr(p, 3))
+ res = dcpe(orig, mod=mod)
+ assert alpha_equal(res, make_nat_expr(p, 3))
+
+
+def test_global_match_nat_id():
+ mod = Module()
+ p = Prelude(mod)
+ add_nat_definitions(p)
+ nat = p.nat()
+ x = Var("x", nat)
+ z_case = Clause(PatternConstructor(p.z, []), p.z())
+ s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x))
+ orig = Match(make_nat_expr(p, 3), [z_case, s_case])
+ res = dcpe(orig, mod=mod)
+ assert alpha_equal(res, make_nat_expr(p, 3))
+
+
+def test_double():
+ mod = Module()
+ p = Prelude(mod)
+ add_nat_definitions(p)
+ orig = p.double(make_nat_expr(p, 3))
+ res = dcpe(orig, mod=mod)
+ assert alpha_equal(res, make_nat_expr(p, 6))
if __name__ == '__main__':
+ test_empty_ad()
test_tuple()
test_const_inline()
test_ref()
test_if_ref()
test_function_invalidate()
test_head_cons()
+ test_map()
+ test_loop()
+ test_swap_loop()
+ test_abs_diff()
+ test_double()
+ test_nat_id()
+ test_global_match_nat_id()
+ test_match_nat_id()