From 2c5c4da697753ca79ea1551cc91c3072cecbbbb1 Mon Sep 17 00:00:00 2001
From: Logan Weber <36520469+weberlo@users.noreply.github.com>
Date: Fri, 15 Nov 2019 14:12:52 -0800
Subject: [PATCH] [Relay][VM][Interpreter] Enable first-class constructors in VM and interpreter via eta expansion (#4218)

* Fix constructor pretty printing

* Make Module::HasDef name consistent with API

* Add VM constructor compilation via eta expansion

* Lint

* Fix CI

* Fix failing test

* Address comment

* Retrigger CI

* Retrigger CI
---
 include/tvm/relay/module.h                 |  14 +--
 include/tvm/relay/transform.h              |   9 +-
 python/tvm/relay/std/prelude.rly           |  15 +--
 python/tvm/relay/transform.py              |  15 ++-
 src/relay/backend/interpreter.cc           |  11 ++
 src/relay/backend/vm/compiler.cc           |   4 +
 src/relay/backend/vm/lambda_lift.cc        |  22 ++--
 src/relay/ir/alpha_equal.cc                |   2 +-
 src/relay/ir/module.cc                     |  12 +--
 src/relay/ir/pretty_printer.cc             |   6 +-
 src/relay/pass/eta_expand.cc               | 159 ++++++++++++++++++++++-------
 src/relay/pass/type_infer.cc               |   2 +-
 tests/python/relay/test_ir_text_printer.py |  22 ++++
 tests/python/relay/test_pass_eta_expand.py |  75 +++++++++++---
 14 files changed, 276 insertions(+), 92 deletions(-)

diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h
index 1ef7ca8..0d3f46c 100644
--- a/include/tvm/relay/module.h
+++ b/include/tvm/relay/module.h
@@ -145,6 +145,13 @@ class ModuleNode : public RelayNode {
   TVM_DLL bool ContainGlobalVar(const std::string& name) const;
 
   /*!
+   * \brief Check if the global_type_var_map_ contains a global type variable.
+   * \param name The variable name.
+   * \returns true if contains, otherwise false.
+   */
+  TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const;
+
+  /*!
    * \brief Lookup a global function by its variable.
    * \param str The unique string specifying the global variable.
    * \returns The global variable.
@@ -199,13 +206,6 @@ class ModuleNode : public RelayNode {
   TVM_DLL TypeData LookupDef(const std::string& var) const;
 
   /*!
-   * \brief Check if a global type definition exists
-   * \param var The name of the global type definition.
-   * \return Whether the definition exists.
-   */
-  TVM_DLL bool HasDef(const std::string& var) const;
-
-  /*!
    * \brief Look up a constructor by its tag.
    * \param tag The tag for the constructor.
    * \return The constructor object.
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 10de087..ddadbe4 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -552,17 +552,20 @@ TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize"
 TVM_DLL Pass CanonicalizeCast();
 
 /*!
- * \brief Add abstraction over a function
+ * \brief Add abstraction over a constructor or global variable bound to a function.
  *
  * For example: `square` is transformed to
- * `fun x -> square x`.
+ * `fn (%x: int32) -> int32 { square(x) }`.
  *
  * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
  * for more details.
  *
+ * \param expand_constructor Whether to expand constructors.
+ * \param expand_global_var Whether to expand global variables.
+ *
  * \return The pass.
  */
-TVM_DLL Pass EtaExpand();
+TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
 
 /*!
  * \brief Print the IR for a module to help debugging.
diff --git a/python/tvm/relay/std/prelude.rly b/python/tvm/relay/std/prelude.rly
index a5c2c9f..fa05d1a 100644
--- a/python/tvm/relay/std/prelude.rly
+++ b/python/tvm/relay/std/prelude.rly
@@ -158,13 +158,9 @@ def @sum(%xs: List[Tensor[(), int32]]) {
 /*
  * Concatenates two lists.
  */
+
 def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] {
-  let %updater = fn(%x: A, %xss: List[A]) -> List[A] {
-    Cons(%x, %xss)
-  };
-  @foldr(%updater, %ys, %xs)
-  // TODO(weberlo): write it like below, once VM constructor compilation is fixed
-  // @foldr(Cons, %ys, %xs)
+  @foldr(Cons, %ys, %xs)
 }
 
 /*
@@ -199,12 +195,7 @@ def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] {
  * Reverses a list.
  */
 def @rev[A](%xs: List[A]) -> List[A] {
-  let %updater = fn(%xss: List[A], %x: A) -> List[A] {
-    Cons(%x, %xss)
-  };
-  @foldl(%updater, Nil, %xs)
-  // TODO(weberlo): write it like below, once VM constructor compilation is fixed
-  // @foldl(@flip(Cons), Nil, %xs)
+  @foldl(@flip(Cons), Nil, %xs)
 }
 
 /*
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py
index 0a7512a..540c1f5 100644
--- a/python/tvm/relay/transform.py
+++ b/python/tvm/relay/transform.py
@@ -529,15 +529,23 @@ def ToCPS(expr, mod=None):
     return _transform.to_cps(expr, mod)
 
 
-def EtaExpand():
-    """Add abstraction over a function
+def EtaExpand(expand_constructor=False, expand_global_var=False):
+    """Add abstraction over a constructor or global variable bound to a function
+
+    Parameters
+    ----------
+    expand_constructor: bool
+        Whether to expand constructors.
+
+    expand_global_var: bool
+        Whether to expand global variables.
 
     Returns
     -------
     ret: tvm.relay.Pass
         The registered pass that eta expands an expression.
     """
-    return _transform.EtaExpand()
+    return _transform.EtaExpand(expand_constructor, expand_global_var)
 
 
 def ToGraphNormalForm():
@@ -959,6 +967,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
         return create_function_pass(pass_func)
     return create_function_pass
 
+
 @function_pass(opt_level=1)
 class ChangeBatch:
     """
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 01693e5..4528358 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -26,6 +26,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -789,6 +790,16 @@ CreateInterpreter(
     Module mod,
     DLContext context,
     Target target) {
+  if (mod.defined()) {
+    // eta expand to support constructors in argument position
+    transform::Sequential seq({
+      transform::EtaExpand(
+        /* expand_constructor */ true, /* expand_global_var */ false)});
+    transform::PassContext pass_ctx = transform::PassContext::Current();
+    tvm::With<transform::PassContext> ctx(pass_ctx);
+    mod = seq(mod);
+  }
+
   auto intrp = std::make_shared<Interpreter>(mod, context, target);
   auto packed = [intrp](Expr expr) {
     auto f = DetectFeature(expr);
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 06705b4..c38ca1a 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -874,6 +874,10 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets)
     pass_seqs.push_back(transform::Legalize());
   }
 
+  // eta expand to support constructors in argument position
+  pass_seqs.push_back(transform::EtaExpand(
+    /* expand_constructor */ true, /* expand_global_var */ false));
+
   pass_seqs.push_back(transform::SimplifyInference());
   PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
     Expr expr = args[0];
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 6290ef7..6ef31e6 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -61,8 +61,8 @@ Function MarkClosure(const Function& func) {
  * We will lift a function out into a global which takes the set of the free
 * vars and then return the new created function.
 */
-struct LambdaLifter : ExprMutator {
-  Module module_;
+class LambdaLifter : public ExprMutator {
+ public:
   explicit LambdaLifter(const Module& module) : module_(module) {}
 
   Expr VisitExpr_(const FunctionNode* func_node) final {
@@ -100,8 +100,8 @@ struct LambdaLifter : ExprMutator {
     // The "inner" function should be used to generate the
     // code for the closure.
     Function lifted_func;
-    if (free_vars.size() == 0) {
-      lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars);
+    if (free_vars.size() == 0 && free_type_vars.size() == 0) {
+      lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
     } else {
       lifted_func = FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
     }
@@ -114,8 +114,15 @@ struct LambdaLifter : ExprMutator {
 
     auto name = GenerateName(lifted_func);
     auto global = GlobalVarNode::make(name);
 
-    // Add the lifted function to the module.
-    module_->Add(global, lifted_func);
+    if (module_->ContainGlobalVar(name)) {
+      const auto existing_func = module_->Lookup(name);
+      CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
+      // If an identical function already exists, use its global var.
+      global = module_->GetGlobalVar(name);
+    } else {
+      // Add the lifted function to the module.
+      module_->Add(global, lifted_func);
+    }
 
     if (free_vars.size() == 0) {
       return std::move(global);
@@ -145,6 +152,9 @@ struct LambdaLifter : ExprMutator {
     }
     return module_;
   }
+
+ private:
+  Module module_;
 };
 
 }  // namespace vm
diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index 0dbcf99..df91f79 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -69,7 +69,7 @@ class AlphaEqualHandler:
     }
     if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
     for (const auto& p : lhsm->type_definitions) {
-      if (!rhsm->HasDef(p.first->var->name_hint) ||
+      if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
           !Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
         return false;
       }
diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc
index 960c28f..3bd8d59 100644
--- a/src/relay/ir/module.cc
+++ b/src/relay/ir/module.cc
@@ -68,6 +68,10 @@ bool ModuleNode::ContainGlobalVar(const std::string& name) const {
   return global_var_map_.find(name) != global_var_map_.end();
 }
 
+bool ModuleNode::ContainGlobalTypeVar(const std::string& name) const {
+  return global_type_var_map_.find(name) != global_type_var_map_.end();
+}
+
 GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
   auto it = global_var_map_.find(name);
   CHECK(it != global_var_map_.end())
@@ -239,11 +243,6 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
   return this->LookupDef(id);
 }
 
-bool ModuleNode::HasDef(const std::string& name) const {
-  auto it = global_type_var_map_.find(name);
-  return it != global_type_var_map_.end();
-}
-
 Constructor ModuleNode::LookupTag(const int32_t tag) {
   auto it = constructor_tag_map_.find(tag);
   CHECK(it != constructor_tag_map_.end())
@@ -336,7 +335,8 @@ TVM_REGISTER_API("relay._module.Module_Add")
   } else if (val->IsInstance<GlobalVarNode>()) {
     GlobalVar gv = Downcast<GlobalVar>(val);
     auto mod_copy = Module(make_node<ModuleNode>(*mod.operator->()));
-    mod_copy = transform::EtaExpand()(mod_copy);
+    mod_copy = transform::EtaExpand(
+      /* expand_constructor */ false, /* expand_global_var */ true)(mod_copy);
     auto func = mod_copy->Lookup(gv->name_hint);
     mod->Add(var, Downcast<Function>(func), update);
   } else {
diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc
index b2a8396..f42069b 100644
--- a/src/relay/ir/pretty_printer.cc
+++ b/src/relay/ir/pretty_printer.cc
@@ -669,7 +669,7 @@ class PrettyPrinter :
   Doc VisitExpr_(const ConstructorNode* n) final {
     Doc doc;
     doc << n->name_hint;
-    if (n->inputs.size() != 0) {
+    if (in_adt_def_ && n->inputs.size() != 0) {
       doc << "(";
       std::vector<Doc> inputs;
       for (Type input : n->inputs) {
@@ -775,6 +775,7 @@ class PrettyPrinter :
   }
 
   Doc VisitType_(const TypeDataNode* node) final {
+    in_adt_def_ = true;
     Doc doc;
     doc << "type " << Print(node->header);
 
@@ -802,6 +803,7 @@ class PrettyPrinter :
       adt_body << ",";
     }
     doc << Brace(adt_body);
+    in_adt_def_ = false;
     return doc;
   }
 
@@ -876,6 +878,8 @@ class PrettyPrinter :
   TextMetaDataContext meta_;
   /*! \brief counter of temporary variable */
   size_t temp_var_counter_{0};
+  /*! \brief whether the printer is currently in an ADT definition */
+  bool in_adt_def_;
   /*! \brief arena for dependency graph */
   common::Arena arena_;
   /*! \brief dependency graph of the expr */
diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc
index a5d0487..dca08cc 100644
--- a/src/relay/pass/eta_expand.cc
+++ b/src/relay/pass/eta_expand.cc
@@ -20,57 +20,144 @@
 /*!
  * \file eta_expand.cc
  *
- * \brief Add abstraction over a function. For example, abs will become (fun x -> abs x).
+ * \brief Add an abstraction over constructors and/or global variables bound to a function.
  *
  */
-#include
 #include
+#include
+#include
+#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
+namespace eta_expand {
+
+/*!
+ * \brief mutator to replace type variables with fresh ones, while maintaining alpha equality
+ */
+class TypeVarReplacer : public TypeMutator {
+ public:
+  TypeVarReplacer() : replace_map_({}) {}
 
-Expr EtaExpand(const Expr& e, const Module& mod) {
-  tvm::Array<Var> original_params;
-  tvm::Array<Expr> params;
-  tvm::Array<Var> args;
-  tvm::Array<TypeVar> original_type_params;
-  Type ret_type;
-
-  if (e->IsInstance<GlobalVarNode>()) {
-    auto gvar_node = e.as<GlobalVarNode>();
-    auto func = mod->Lookup(GetRef<GlobalVar>(gvar_node));
-    original_params = func->params;
-    original_type_params = func->type_params;
-    ret_type = func->ret_type;
-  } else {
-    CHECK(e->IsInstance<FunctionNode>());
-    auto func = GetRef<Function>(e.as<FunctionNode>());
-    original_params = func->params;
-    original_type_params = func->type_params;
-    ret_type = func->ret_type;
+  Type VisitType_(const TypeVarNode* type_var_node) final {
+    const auto type_var = GetRef<TypeVar>(type_var_node);
+    if (replace_map_.find(type_var) == replace_map_.end()) {
+      replace_map_[type_var] = TypeVarNode::make("A", Kind::kType);
+    }
+    return replace_map_[type_var];
   }
 
-  for (size_t i = 0; i < original_params.size(); ++i) {
-    auto var = VarNode::make("a", original_params[i]->type_annotation);
-    params.push_back(var);
-    args.push_back(var);
+ private:
+  /*! \brief variable replacement map to remap old type vars to fresh ones */
+  std::unordered_map replace_map_;
+};
+
+/*!
+ * \brief mutator to perform eta expansion on all functions in a module
+ */
+class EtaExpander : public ExprMutator {
+ public:
+  explicit EtaExpander(const Module& mod, bool expand_constructor, bool expand_global_var)
+      : mod_(mod),
+        type_var_replacer_(TypeVarReplacer()),
+        expand_constructor_(expand_constructor),
+        expand_global_var_(expand_global_var) {
+    CHECK(expand_constructor || expand_global_var)
+      << "must expand at least one language feature";
   }
 
-  auto new_func =
-      FunctionNode::make(args, CallNode::make(e, params), ret_type, original_type_params);
+  Module Expand() {
+    for (GlobalVar global_var : mod_->GetGlobalVars()) {
+      const Function func = mod_->Lookup(global_var);
+      const Function new_func = Downcast<Function>(VisitExpr(func));
+      mod_->Update(global_var, new_func);
+    }
+    return mod_;
+  }
 
-  return std::move(new_func);
-}
+  Expr VisitExpr_(const CallNode* call) final {
+    // we don't need to expand constructors when they are being called, so we
+    // prevent them being visited here
+    Expr new_op = call->op;
+    if (!call->op.as<ConstructorNode>()) {
+      new_op = VisitExpr(new_op);
+    }
+    tvm::Array<Expr> new_args;
+    for (const auto& arg : call->args) {
+      new_args.push_back(VisitExpr(arg));
+    }
+    return CallNode::make(new_op, new_args, call->attrs, call->type_args);
+  }
+
+  Expr VisitExpr_(const ConstructorNode* cons_node) final {
+    Constructor cons = GetRef<Constructor>(cons_node);
+    if (!expand_constructor_) {
+      return std::move(cons);
+    }
+    // NOTE: we only reach this case if the constructor is not being applied to any arguments
+    tvm::Array<Expr> params;
+    for (const auto& type : cons->inputs) {
+      Type param_type = type_var_replacer_.VisitType(type);
+      params.push_back(VarNode::make("eta_expand_param", param_type));
+    }
+    tvm::Array<Type> type_params;
+    TypeData adt_def = mod_->LookupDef(cons->belong_to);
+    for (const auto& type_var : adt_def->type_vars) {
+      type_params.push_back(type_var_replacer_.VisitType(type_var));
+    }
+    Expr body = CallNode::make(cons, params, Attrs());
+    Type ret_type = TypeCallNode::make(cons->belong_to, type_params);
+
+    return FunctionNode::make(
+      Downcast<tvm::Array<Var>>(params),
+      body,
+      ret_type,
+      Downcast<tvm::Array<TypeVar>>(type_params));
+  }
+
+  Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
+    GlobalVar gvar = GetRef<GlobalVar>(gvar_node);
+    if (!expand_global_var_) {
+      return std::move(gvar);
+    }
+
+    const auto func = mod_->Lookup(gvar);
+    tvm::Array<Expr> params;
+    tvm::Array<Var> args;
+    for (size_t i = 0; i < func->params.size(); ++i) {
+      auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
+      params.push_back(var);
+      args.push_back(var);
+    }
+
+    return FunctionNode::make(
+      args,
+      CallNode::make(gvar, params),
+      func->ret_type,
+      func->type_params);
+  }
+
+ private:
+  /*! \brief reference to module being expanded */
+  const Module mod_;
+  /*! \brief type variable replacer */
+  TypeVarReplacer type_var_replacer_;
+  /*! \brief whether to expand constructor nodes */
+  bool expand_constructor_;
+  /*! \brief whether to expand global variable nodes */
+  bool expand_global_var_;
+};
+
+}  // namespace eta_expand
 
 namespace transform {
 
-Pass EtaExpand() {
-  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
-    [=](Function f, Module m, PassContext pc) {
-      return Downcast<Function>(EtaExpand(f, m));
-    };
-  Pass expanded = CreateFunctionPass(pass_func, 1, "EtaExpand", {});
-  return Sequential({expanded, InferType()});
+Pass EtaExpand(bool expand_constructor, bool expand_global_var) {
+  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
+    [=](Module mod, PassContext pc) {
+      return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand();
+    };
+  return CreateModulePass(pass_func, 1, "EtaExpand", {});
 }
 
 TVM_REGISTER_API("relay._transform.EtaExpand")
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index bc84bdd..9d68781 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -653,7 +653,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
   }
 
   Expr VisitExpr_(const ConstructorNode* op) final {
-    return GetRef<Constructor>(op);
+    return AttachCheckedType(op);
   }
 
   Expr VisitExpr_(const MatchNode* op) final {
diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index 0d6a02e..6426bf3 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -218,6 +218,27 @@ def test_zeros():
     x = relay.op.zeros([], "float32")
     astext(x)
 
+
+def test_unapplied_constructor():
+    type_def_str = r"""
+type List[A] {
+  Cons(A, List[A]),
+  Nil,
+}
+    """
+    main_def_str = r"""
+def @main[A]() -> fn (A, List[A]) -> List[A] {
+  Cons
+}
+    """
+    mod = relay.fromtext(SEMVER + type_def_str + main_def_str)
+    mod_str = str(mod)
+    # ensure constructors are printed correctly in type definitions (with their
+    # signature) and as exprs (without their signature)
+    assert type_def_str.strip() in mod_str
+    assert main_def_str.strip() in mod_str
+
+
 if __name__ == "__main__":
     do_print[0] = True
     test_lstm()
@@ -239,3 +260,4 @@ if __name__ == "__main__":
     test_let_if_scope()
     test_variable_name()
     test_call_node_order()
+    test_unapplied_constructor()
diff --git a/tests/python/relay/test_pass_eta_expand.py b/tests/python/relay/test_pass_eta_expand.py
index 73c3a4e..b9eb2a1 100644
--- a/tests/python/relay/test_pass_eta_expand.py
+++ b/tests/python/relay/test_pass_eta_expand.py
@@ -14,27 +14,70 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
+
+import numpy as np
+
+import tvm
 from tvm import relay
-import tvm.relay.module as _module
 import tvm.relay.transform as _transform
 
 
-def test_eta_expand_basic():
-    x = relay.var('x', 'int32')
-    orig = relay.Function([x], x)
-    mod = _module.Module.from_expr(orig)
-    seq = _transform.Sequential([_transform.EtaExpand()])
+def test_eta_expand_global_var():
+    mod = relay.fromtext(r"""
+        v0.0.4
+        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
+            %x
+        }
+        def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
+            @aux
+        }
+    """)
+    seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
     with _transform.PassContext(opt_level=3):
         mod = seq(mod)
+    expected = relay.fromtext(r"""
+        v0.0.4
+        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
+            %x
+        }
+        def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
+            fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
+                @aux(%x)
+            }
+        }
+    """)
+    relay.analysis.assert_graph_equal(mod['main'], expected['main'])
+
 
-    got = mod["main"]
 
+def test_eta_expand_constructor():
+    mod = relay.fromtext(r"""
+        v0.0.4
+        type List[A] {
+            Cons(A, List[A]),
+            Nil,
+        }
+        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
+            Cons
+        }
+    """)
+    seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
+    with _transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    expected = relay.fromtext(r"""
+        v0.0.4
+        type List[A] {
+            Cons(A, List[A]),
+            Nil,
+        }
+        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
+            fn [A](%x: A, %xs: List[A]) -> List[A] {
+                Cons(%x, %xs)
+            }
+        }
+    """)
+    relay.analysis.assert_graph_equal(mod['main'], expected['main'])
 
-    y = relay.var('y', 'int32')
-    expected = relay.Function([y], orig(y))
-    gv = relay.GlobalVar("gv")
-    mod[gv] = expected
-    mod = _transform.InferType()(mod)
-    expected = mod["gv"]
-    assert(relay.analysis.alpha_equal(got, expected))
 
-if __name__ == "__main__":
-    test_eta_expand_basic()
+
+if __name__ == '__main__':
+    test_eta_expand_global_var()
+    test_eta_expand_constructor()
-- 
2.7.4
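
Reviewer note (not part of the patch): the snippet below is a minimal usage sketch
of the new EtaExpand flags, written against the same relay.fromtext /
tvm.relay.transform idioms that the tests above use. It assumes a TVM build that
includes this patch; the module text and variable names are purely illustrative.

    from tvm import relay
    import tvm.relay.transform as _transform

    # A module whose @main returns the List constructor `Cons` unapplied.
    mod = relay.fromtext(r"""
        v0.0.4
        type List[A] {
            Cons(A, List[A]),
            Nil,
        }
        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
            Cons
        }
    """)

    # Eta-expand constructors only, mirroring what the VM compiler and the
    # interpreter now do internally before compiling or evaluating a module.
    seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
    with _transform.PassContext(opt_level=3):
        mod = seq(mod)

    # @main now returns a function alpha-equal to
    #   fn [A](%x: A, %xs: List[A]) -> List[A] { Cons(%x, %xs) }
    # instead of the bare constructor.
    print(mod['main'])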