[Relay][Module] Refactor the way we interface between different modules of Relay...
authorJared Roesch <roeschinc@gmail.com>
Thu, 12 Sep 2019 03:39:56 +0000 (22:39 -0500)
committerZhi <5145158+zhiics@users.noreply.github.com>
Thu, 12 Sep 2019 03:39:56 +0000 (20:39 -0700)
* Module refactor

* Add load module

* Add support for idempotent import

* Tweak load paths

* Move path around

* Expose C++ import functions in Python

* Fix import

* Add doc string

* Fix

* Fix lint

* Fix lint

* Fix test failure

* Add type solver

* Fix lint

12 files changed:
include/tvm/relay/expr.h
include/tvm/relay/module.h
include/tvm/relay/type.h
python/tvm/relay/__init__.py
python/tvm/relay/module.py
python/tvm/relay/prelude.py
python/tvm/relay/std/prelude.rly [moved from python/tvm/relay/prelude.rly with 100% similarity]
src/relay/ir/expr_functor.cc
src/relay/ir/module.cc
src/relay/pass/type_infer.cc
src/relay/pass/type_solver.cc
src/relay/pass/type_solver.h

index c5cd6bb..b1b8d6a 100644 (file)
@@ -575,6 +575,7 @@ std::string PrettyPrint(const NodeRef& node);
 std::string AsText(const NodeRef& node,
                    bool show_meta_data = true,
                    runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_EXPR_H_
index 3496c88..ee9b487 100644 (file)
@@ -33,6 +33,7 @@
 #include <string>
 #include <vector>
 #include <unordered_map>
+#include <unordered_set>
 
 namespace tvm {
 namespace relay {
@@ -185,6 +186,23 @@ class ModuleNode : public RelayNode {
    */
   TVM_DLL void Update(const Module& other);
 
+  /*!
+   * \brief Import Relay code from the file at path.
+   * \param path The path of the Relay code to import.
+   *
+   * \note The path resolution behavior is standard,
+   * if abosolute will be the absolute file, if
+   * relative it will be resovled against the current
+   * working directory.
+   */
+  TVM_DLL void Import(const std::string& path);
+
+  /*!
+   * \brief Import Relay code from the file at path, relative to the standard library.
+   * \param path The path of the Relay code to import.
+   */
+  TVM_DLL void ImportFromStd(const std::string& path);
+
   /*! \brief Construct a module from a standalone expression.
    *
    * Allows one to optionally pass a global function map and
@@ -222,6 +240,11 @@ class ModuleNode : public RelayNode {
    * for convenient access
    */
   std::unordered_map<int32_t, Constructor> constructor_tag_map_;
+
+  /*! \brief The files previously imported, required to ensure
+      importing is idempotent for each module.
+   */
+  std::unordered_set<std::string> import_set_;
 };
 
 struct Module : public NodeRef {
@@ -235,6 +258,12 @@ struct Module : public NodeRef {
   using ContainerType = ModuleNode;
 };
 
+/*! \brief Parse Relay source into a module.
+ * \param source A string of Relay source code.
+ * \param source_name The name of the source file.
+ * \return A Relay module.
+ */
+Module FromText(const std::string& source, const std::string& source_name);
 
 }  // namespace relay
 }  // namespace tvm
index d509fde..16e3678 100644 (file)
@@ -410,6 +410,12 @@ class TypeReporterNode : public Node {
    */
   TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;
 
+  /*!
+   * \brief Retrieve the current global module.
+   * \return The global module.
+   */
+  TVM_DLL virtual Module GetModule() = 0;
+
   // solver is not serializable.
   void VisitAttrs(tvm::AttrVisitor* v) final {}
 
index 092cd01..ceb98c4 100644 (file)
@@ -17,6 +17,7 @@
 # pylint: disable=wildcard-import, redefined-builtin, invalid-name
 """The Relay IR namespace containing the IR definition and compiler."""
 from __future__ import absolute_import
+import os
 from sys import setrecursionlimit
 from ..api import register_func
 from . import base
index e0511a2..57980dd 100644 (file)
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
 """A global module storing everything needed to interpret or compile a Relay program."""
+import os
 from .base import register_relay_node, RelayNode
+from .. import register_func
 from .._ffi import base as _base
 from . import _make
 from . import _module
 from . import expr as _expr
 from . import ty as _ty
 
+__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
+
+@register_func("tvm.relay.std_path")
+def _std_path():
+    global __STD_PATH__
+    return __STD_PATH__
+
 @register_relay_node
 class Module(RelayNode):
     """The global Relay module containing collection of functions.
@@ -202,3 +211,9 @@ class Module(RelayNode):
         funcs = functions if functions is not None else {}
         defs = type_defs if type_defs is not None else {}
         return _module.Module_FromExpr(expr, funcs, defs)
+
+    def _import(self, file_to_import):
+        return _module.Module_Import(self, file_to_import)
+
+    def import_from_std(self, file_to_import):
+        return _module.Module_ImportFromStd(self, file_to_import)
index f9a7d3d..d05b669 100644 (file)
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """A prelude containing useful global functions and ADT definitions."""
-import os
 from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
 from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
 from .op.tensor import add, subtract, equal
 from .adt import Constructor, TypeData, Clause, Match
 from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple
-from .parser import fromtext
-__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
 from .module import Module
 
 class Prelude:
@@ -479,12 +476,10 @@ class Prelude:
         Parses the portions of the Prelude written in Relay's text format and adds
         them to the module.
         """
-        prelude_file = os.path.join(__PRELUDE_PATH__, "prelude.rly")
-        with open(prelude_file) as prelude:
-            prelude = fromtext(prelude.read())
-            self.mod.update(prelude)
-            self.id = self.mod.get_global_var("id")
-            self.compose = self.mod.get_global_var("compose")
+        # TODO(@jroesch): we should remove this helper when we port over prelude
+        self.mod.import_from_std("prelude.rly")
+        self.id = self.mod.get_global_var("id")
+        self.compose = self.mod.get_global_var("compose")
 
 
     def __init__(self, mod=None):
index da9f7b8..6a2db6b 100644 (file)
@@ -444,7 +444,6 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
   }
 }
 
-
 TVM_REGISTER_API("relay._expr.Bind")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     NodeRef input = args[0];
index dbaea7f..2601f35 100644 (file)
@@ -26,6 +26,8 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/transform.h>
 #include <sstream>
+#include <fstream>
+#include <unordered_set>
 
 namespace tvm {
 namespace relay {
@@ -38,6 +40,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
   auto n = make_node<ModuleNode>();
   n->functions = std::move(global_funcs);
   n->type_definitions = std::move(global_type_defs);
+  n->global_type_var_map_ = {};
+  n->global_var_map_ = {};
+  n->constructor_tag_map_ = {};
 
   for (const auto& kv : n->functions) {
     // set global var map
@@ -85,6 +90,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
 }
 
 GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
+  CHECK(global_type_var_map_.defined());
   auto it = global_type_var_map_.find(name);
   CHECK(it != global_type_var_map_.end())
     << "Cannot find global type var " << name << " in the Module";
@@ -162,6 +168,7 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
   // set global type var map
   CHECK(!global_type_var_map_.count(var->var->name_hint))
     << "Duplicate global type definition name " << var->var->name_hint;
+
   global_type_var_map_.Set(var->var->name_hint, var);
   RegisterConstructors(var, type);
 
@@ -241,6 +248,40 @@ Module ModuleNode::FromExpr(
   return mod;
 }
 
+void ModuleNode::Import(const std::string& path) {
+  LOG(INFO) << "Importing: " << path;
+  if (this->import_set_.count(path) == 0) {
+    this->import_set_.insert(path);
+    std::fstream src_file(path, std::fstream::in);
+    std::string file_contents {
+      std::istreambuf_iterator<char>(src_file),
+      std::istreambuf_iterator<char>() };
+    auto mod_to_import = FromText(file_contents, path);
+
+    for (auto func : mod_to_import->functions) {
+      this->Add(func.first, func.second, false);
+    }
+
+    for (auto type : mod_to_import->type_definitions) {
+      this->AddDef(type.first, type.second);
+    }
+  }
+}
+
+void ModuleNode::ImportFromStd(const std::string& path) {
+  auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
+  CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
+  std::string std_path = (*f)();
+  return this->Import(std_path + "/" + path);
+}
+
+Module FromText(const std::string& source, const std::string& source_name) {
+  auto* f = tvm::runtime::Registry::Get("relay.fromtext");
+  CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
+  Module mod = (*f)(source, source_name);
+  return mod;
+}
+
 TVM_REGISTER_NODE_TYPE(ModuleNode);
 
 TVM_REGISTER_API("relay._make.Module")
@@ -320,6 +361,16 @@ TVM_REGISTER_API("relay._module.Module_Update")
   mod->Update(from);
 });
 
+TVM_REGISTER_API("relay._module.Module_Import")
+.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
+  mod->Import(path);
+});
+
+TVM_REGISTER_API("relay._module.Module_ImportFromStd")
+.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
+  mod->ImportFromStd(path);
+});;
+
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<ModuleNode>(
   [](const ModuleNode *node, tvm::IRPrinter *p) {
index e8bdc09..5b9b25b 100644 (file)
@@ -108,7 +108,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
 
   explicit TypeInferencer(Module mod, GlobalVar current_func)
       : mod_(mod), current_func_(current_func),
-        err_reporter(), solver_(current_func, &this->err_reporter) {
+        err_reporter(), solver_(current_func, mod, &this->err_reporter) {
+    CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer";
   }
 
   // inference the type of expr.
@@ -790,36 +791,22 @@ void EnsureCheckedType(const Expr& e) {
   AllCheckTypePopulated().VisitExpr(e);
 }
 
-Expr InferType(const Expr& expr, const Module& mod_ref) {
-  if (!mod_ref.defined()) {
-    Module mod = ModuleNode::FromExpr(expr);
-    // NB(@jroesch): By adding the expression to the module we will
-    // type check it anyway; afterwards we can just recover type
-    // from the type-checked function to avoid doing unnecessary work.
-
-    Function func = mod->Lookup("main");
-
-    // FromExpr wraps a naked expression as a function, we will unbox
-    // it here.
-    if (expr.as<FunctionNode>()) {
-      return std::move(func);
-    } else {
-      return func->body;
-    }
-  } else {
-    auto e = TypeInferencer(mod_ref, mod_ref->GetGlobalVar("main")).Infer(expr);
-    CHECK(WellFormed(e));
-    auto free_tvars = FreeTypeVars(e, mod_ref);
-    CHECK(free_tvars.size() == 0)
-      << "Found unbound type variables in " << e << ": " << free_tvars;
-    EnsureCheckedType(e);
-    return e;
-  }
+Expr InferType(const Expr& expr, const Module& mod) {
+  auto main = mod->GetGlobalVar("main");
+  auto inferencer = TypeInferencer(mod, main);
+  auto e = inferencer.Infer(expr);
+  CHECK(WellFormed(e));
+  auto free_tvars = FreeTypeVars(e, mod);
+  CHECK(free_tvars.size() == 0)
+    << "Found unbound type variables in " << e << ": " << free_tvars;
+  EnsureCheckedType(e);
+  return e;
 }
 
 Function InferType(const Function& func,
                    const Module& mod,
                    const GlobalVar& var) {
+  CHECK(mod.defined()) << "internal error: module must be set for type inference";
   Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
   func_copy->checked_type_ = func_copy->func_type_annotation();
   mod->AddUnchecked(var, func_copy);
index 743a4c7..31edd3b 100644 (file)
@@ -61,6 +61,10 @@ class TypeSolver::Reporter : public TypeReporterNode {
     location = ref;
   }
 
+  TVM_DLL Module GetModule() final {
+    return this->solver_->module_;
+  }
+
  private:
   /*! \brief The location to report unification errors at. */
   mutable NodeRef location;
@@ -526,10 +530,13 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
 };
 
 // constructor
-TypeSolver::TypeSolver(const GlobalVar &current_func, ErrorReporter* err_reporter)
-  : reporter_(make_node<Reporter>(this)),
-    current_func(current_func),
-    err_reporter_(err_reporter) {
+TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module,
+                       ErrorReporter* err_reporter)
+    : reporter_(make_node<Reporter>(this)),
+      current_func(current_func),
+      err_reporter_(err_reporter),
+      module_(module) {
+  CHECK(module_.defined()) << "internal error: module must be defined";
 }
 
 // destructor
@@ -653,18 +660,22 @@ TVM_REGISTER_API("relay._analysis._test_type_solver")
     using runtime::PackedFunc;
     using runtime::TypedPackedFunc;
     ErrorReporter *err_reporter = new ErrorReporter();
-    auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), err_reporter);
+    auto module = ModuleNode::make({}, {});
+    auto dummy_fn_name = GlobalVarNode::make("test");
+    module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
+    auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
 
-    auto mod = [solver, err_reporter](std::string name) -> PackedFunc {
+    auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc {
       if (name == "Solve") {
         return TypedPackedFunc<bool()>([solver]() {
             return solver->Solve();
           });
       } else if (name == "Unify") {
-        return TypedPackedFunc<Type(Type, Type)>([solver, err_reporter](Type lhs, Type rhs) {
+        return TypedPackedFunc<Type(Type, Type)>(
+          [module, solver, err_reporter](Type lhs, Type rhs) {
             auto res = solver->Unify(lhs, rhs, lhs);
             if (err_reporter->AnyErrors()) {
-              err_reporter->RenderErrors(ModuleNode::make({}, {}), true);
+              err_reporter->RenderErrors(module, true);
             }
             return res;
           });
index 2857963..4a6d2cf 100644 (file)
@@ -63,7 +63,7 @@ using common::LinkedList;
  */
 class TypeSolver {
  public:
-  TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter);
+  TypeSolver(const GlobalVar& current_func, const Module& _mod, ErrorReporter* err_reporter);
   ~TypeSolver();
   /*!
    * \brief Add a type constraint to the solver.
@@ -179,6 +179,8 @@ class TypeSolver {
   GlobalVar current_func;
   /*! \brief Error reporting. */
   ErrorReporter* err_reporter_;
+  /*! \brief The module. */
+  Module module_;
 
   /*!
    * \brief GetTypeNode that is corresponds to t.