* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2018 by Contributors
+ * Copyright (c) 2019 by Contributors
* \file src/tvm/relay/expr_mutator.cc
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* the cost of using functional updates.
*/
#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
#include "type_functor.h"
namespace tvm {
});
// Implement bind.
-class ExprBinder : public ExprMutator {
+class ExprBinder : public ExprMutator, PatternMutator {
public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
: args_map_(args_map) {
}
}
+ Pattern VisitPattern(const Pattern& p) final {
+ return PatternMutator::VisitPattern(p);
+ }
+
+ Clause VisitClause(const Clause& c) final {
+ Pattern pat = VisitPattern(c->lhs);
+ return ClauseNode::make(pat, VisitExpr(c->rhs));
+ }
+
+ Var VisitVar(const Var& v) final {
+ return Downcast<Var>(VisitExpr(v));
+ }
+
private:
const tvm::Map<Var, Expr>& args_map_;
};
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) {
- Expr new_body = ExprBinder(args_map).Mutate(func->body);
+ Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
Array<Var> new_params;
for (Var param : func->params) {
if (!args_map.count(param)) {
func->type_params,
func->attrs);
} else {
- return ExprBinder(args_map).Mutate(expr);
+ return ExprBinder(args_map).VisitExpr(expr);
}
}
* 3: The generated code reuses bindings (although they are not shadowed),
* so we have to deduplicate them.
*
- * 4: In the generated code, multiple VarNode might have same Id.
+ * 4: In the generated code, as it call TypeSubst, multiple VarNode might have same Id.
* While it is permitted, most pass use NodeHash for Var,
* and having multiple VarNode for same Id break them.
* Thus we remap them to a single Id for now.
}
using Func = std::function<PStatic(const std::vector<PStatic>&,
- const Attrs&,
- const Array<Type>&,
- LetList*)>;
+ const Attrs&,
+ const Array<Type>&,
+ LetList*)>;
struct SFuncNode : StaticNode {
Func func;
void Insert(const Var& v, const PStatic& ps) {
CHECK(ps.defined());
+ CHECK_EQ(env_.back().locals.count(v), 0);
env_.back().locals[v] = ps;
}
/*!
* \brief As our store require rollback, we implement it as a frame.
- * every time we need to copy the store, a new frame is insert.
- * every time we roll back, a frame is popped.
+ *
+ * Every time we need to copy the store, a new frame is insert.
+ * Every time we roll back, a frame is popped.
*/
struct StoreFrame {
std::unordered_map<const SRefNode*, PStatic> store;
- /*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */
+ /*!
+ * \brief On unknown effect, history_valid is set to true to signal above frame is outdated.
+ *
+ * It only outdate the frame above it, but not the current frame.
+ */
bool history_valid = true;
explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) { }
StoreFrame() = default;
}
void Insert(const SRefNode* r, const PStatic& ps) {
+ CHECK(r);
store_.back().store[r] = ps;
}
PStatic Lookup(const SRefNode* r) {
auto rit = store_.rbegin();
while (rit != store_.rend()) {
- if (!rit->history_valid) {
- return PStatic();
- }
if (rit->store.find(r) != rit->store.end()) {
return rit->store.find(r)->second;
}
+ if (!rit->history_valid) {
+ return PStatic();
+ }
++rit;
}
return PStatic();
}
void Invalidate() {
- store_.back().history_valid = false;
+ StoreFrame sf;
+ sf.history_valid = false;
+ store_.push_back(sf);
}
private:
store_->store_.push_back(StoreFrame());
}
~StoreFrameContext() {
+ // push one history valid frame off.
+ while (!store_->store_.back().history_valid) {
+ store_->store_.pop_back();
+ }
store_->store_.pop_back();
}
};
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
public:
- PartialEvaluator(const tvm::Array<Var>& free_vars,
- const Module& mod) :
- mod_(mod) {
- for (const Var& v : free_vars) {
- env_.Insert(v, NoStatic(v));
- }
- }
+ PartialEvaluator(const Module& mod) : mod_(mod) { }
PStatic VisitExpr(const Expr& e, LetList* ll) final {
PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
return env_.Lookup(GetRef<Var>(op));
}
- PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
- GlobalVar gv = GetRef<GlobalVar>(op);
+ PStatic VisitGlobalVar(const GlobalVar& gv) {
+ CHECK(mod_.defined());
if (gv_map_.count(gv) == 0) {
- if (mod_.defined()) {
- Function func = mod_->Lookup(gv);
- InitializeFuncId(func);
- Func f = VisitFuncStatic(func, gv);
- gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
- func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
- mod_->Update(gv, func);
- } else {
- gv_map_.insert({gv, NoStatic(gv)});
- }
+ Function func = mod_->Lookup(gv);
+ InitializeFuncId(func);
+ Func f = VisitFuncStatic(func, gv);
+ gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
+ func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
+ mod_->Update(gv, func);
}
return gv_map_.at(gv);
}
+ PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
+ return VisitGlobalVar(GetRef<GlobalVar>(op));
+ }
+
PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
env_.Insert(op->var, VisitExpr(op->value, ll));
return VisitExpr(op->body, ll);
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
- subst.Set(func->type_params[i], Type());
+ subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
}
std::vector<Time> args_time;
for (const auto& v : pv) {
};
}
-
Expr VisitFuncDynamic(const Function& func, const Func& f) {
return store_.Extend<Expr>([&]() {
- store_.Invalidate();
- return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
- std::vector<PStatic> pv;
- for (const auto& v : func->params) {
- pv.push_back(NoStatic(v));
- }
- tvm::Array<Type> type_args;
- for (const auto& tp : func->type_params) {
- type_args.push_back(tp);
- }
- return f(pv, Attrs(), type_args, ll)->dynamic;
- }), func->ret_type, func->type_params, func->attrs);
- });
+ store_.Invalidate();
+ return FunctionNode::make(func->params,
+ LetList::With([&](LetList* ll) {
+ std::vector<PStatic> pv;
+ for (const auto& v : func->params) {
+ pv.push_back(NoStatic(v));
+ }
+ tvm::Array<Type> type_args;
+ for (const auto& tp : func->type_params) {
+ type_args.push_back(tp);
+ }
+ return f(pv, Attrs(), type_args, ll)->dynamic;
+ }), func->ret_type, func->type_params, func->attrs);
+ });
}
PStatic VisitFunc(const Function& func, LetList* ll) {
Module PartialEval(const Module& m) {
CHECK(m->entry_func.defined());
- auto func = m->Lookup(m->entry_func);
- Expr ret =
- TransformF([&](const Expr& e) {
- return LetList::With([&](LetList* ll) {
- relay::partial_eval::PartialEvaluator pe(FreeVars(e), m);
- pe.InitializeFuncId(e);
- return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic);
- });
- }, func);
- CHECK(ret->is_type<FunctionNode>());
- m->Update(m->entry_func, Downcast<Function>(ret));
+ relay::partial_eval::PartialEvaluator pe(m);
+ std::vector<GlobalVar> gvs;
+ for (const auto& p : m->functions) {
+ gvs.push_back(p.first);
+ }
+ for (const auto& gv : gvs) {
+ pe.VisitGlobalVar(gv);
+ }
return m;
}