[REFACTOR][IR] Migrate IRModule ObjectRef to not-null (#5654)
authorANSHUMAN TRIPATHY <anshuman.t@huawei.com>
Fri, 22 May 2020 21:32:59 +0000 (03:02 +0530)
committerGitHub <noreply@github.com>
Fri, 22 May 2020 21:32:59 +0000 (14:32 -0700)
include/tvm/ir/module.h
python/tvm/relay/transform/transform.py
src/relay/analysis/feature.cc
src/relay/analysis/match_exhaustion.cc
src/relay/backend/compile_engine.h
src/relay/backend/graph_runtime_codegen.cc
src/relay/transforms/gradient.cc
src/tir/transforms/split_host_device.cc

index ba9a62a..9317f1e 100644 (file)
@@ -285,7 +285,7 @@ class IRModule : public ObjectRef {
                             Map<GlobalTypeVar, TypeData> type_definitions = {},
                             std::unordered_set<String> import_set = {});
   /*! \brief default constructor */
-  IRModule() {}
+  IRModule() : IRModule(Map<GlobalVar, BaseFunc>()) {}
   /*!
    * \brief constructor
    * \param n The object pointer.
@@ -299,12 +299,6 @@ class IRModule : public ObjectRef {
   }
 
   /*!
-   * \brief Construct an empty module.
-   *
-   * \returns The constructed module
-   */
-  static IRModule Empty() { return IRModule(Map<GlobalVar, BaseFunc>()); }
-  /*!
    * \brief Construct a module from a standalone expression.
    *
    * Allows one to optionally pass a global function map and
@@ -330,6 +324,10 @@ class IRModule : public ObjectRef {
 
   /*! \brief Declare the container type. */
   using ContainerType = IRModuleNode;
+
+  /*! \brief Declare whether Ref is nullable. */
+  static constexpr bool _type_is_nullable = false;
+
   // allow copy on write.
   TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
 };
index c58c679..1d17c42 100644 (file)
@@ -656,7 +656,8 @@ def to_cps(func, mod=None):
     result: tvm.relay.Function
       The output function.
     """
-    return _ffi_api.to_cps(func, mod)
+    use_mod = mod if mod is not None else tvm.ir.IRModule()
+    return _ffi_api.to_cps(func, use_mod)
 
 
 def un_cps(func):
index 9e94459..6be956c 100644 (file)
@@ -96,8 +96,11 @@ FeatureSet DetectFeature(const IRModule& mod) {
   return fs;
 }
 
-Array<Integer> PyDetectFeature(const Expr& expr, const IRModule& mod) {
-  FeatureSet fs = DetectFeature(expr) + DetectFeature(mod);
+Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod) {
+  FeatureSet fs = DetectFeature(expr);
+  if (mod.defined()) {
+    fs = fs + DetectFeature(mod.value());
+  }
   return static_cast<Array<Integer>>(fs);
 }
 
index 96dab6b..e852c40 100644 (file)
@@ -305,11 +305,8 @@ Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) {
 
 // expose for testing only
 TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases")
-    .set_body_typed([](const Match& match, const IRModule& mod_ref) {
-      IRModule call_mod = mod_ref;
-      if (!call_mod.defined()) {
-        call_mod = IRModule({}, {});
-      }
+    .set_body_typed([](const Match& match, const Optional<IRModule>& mod_ref) {
+      IRModule call_mod = mod_ref.defined() ? mod_ref.value() : IRModule({}, {});
       return UnmatchedCases(match, call_mod);
     });
 
index 9abe80c..a5f3f63 100644 (file)
@@ -82,7 +82,7 @@ struct CachedFuncNode : public Object {
   /*! \brief The schedule to the function */
   te::Schedule schedule;
   /*! \brief The lowered functions to support the function. */
-  IRModule funcs = IRModule::Empty();
+  IRModule funcs = IRModule();
 
   /*! \brief Parameter usage states in the shape function. */
   tvm::Array<Integer> shape_func_param_states;
index c8ec1bf..85d439b 100644 (file)
@@ -207,7 +207,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
 
     for (auto& kv : lowered_funcs_) {
       if (ret.lowered_funcs.count(kv.first) == 0) {
-        ret.lowered_funcs.Set(kv.first, IRModule::Empty());
+        ret.lowered_funcs.Set(kv.first, IRModule());
       }
       auto& mod = ret.lowered_funcs[kv.first];
       mod->Update(kv.second);
@@ -395,7 +395,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
     CCacheKey key = (*pf0)(func, target);
     CachedFunc lowered_func = (*pf1)(compile_engine_, key);
     if (!lowered_funcs_.count(target->str())) {
-      lowered_funcs_[target->str()] = IRModule::Empty();
+      lowered_funcs_[target->str()] = IRModule();
     }
     lowered_funcs_[target->str()]->Update(lowered_func->funcs);
     return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name);
index afe5568..7801c03 100644 (file)
@@ -68,7 +68,7 @@ Type WithGradientType(const Type&);
 /*! return an expression that represent differentiation of e (according to WithGradientType).
  *  This version only work on first order code without control flow.
  */
-Expr FirstOrderGradient(const Expr& e, const IRModule& mod);
+Expr FirstOrderGradient(const Expr& e, const Optional<IRModule>& mod);
 
 Type WithGradientType(const Type& t) {
   // TODO(M.K.): stricter checking
@@ -78,9 +78,11 @@ Type WithGradientType(const Type& t) {
 }
 
 //! \brief if the expression is a GlobalVar, transform to it's expression.
-Expr DeGlobal(const IRModule& mod, const Expr& e) {
-  if (const auto* x = e.as<GlobalVarNode>()) {
-    BaseFunc base_func = mod->Lookup(GetRef<GlobalVar>(x));
+Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
+  const auto* x = e.as<GlobalVarNode>();
+
+  if (mod.defined() && (x)) {
+    BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
     if (auto* n = base_func.as<FunctionNode>()) {
       return n->body;
     } else {
@@ -214,7 +216,7 @@ Type GradRetType(const Function& f) {
   return TupleType({f->ret_type, TupleType(vt)});
 }
 
-Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
+Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {
   // Currently we first remove any global functions for the first
   // order case.
   auto e = DeGlobal(mod, re);
@@ -482,7 +484,7 @@ bool MissingGrad(const Expr& e) {
   return false;
 }
 
-Expr Gradient(const Expr& re, const IRModule& mod) {
+Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
   auto e = DeGlobal(mod, re);
   auto f = e.as<FunctionNode>();
   CHECK(f) << "input need to be a function";
index 9bdb0e2..98577a7 100644 (file)
@@ -275,7 +275,7 @@ Pass SplitHostDevice() {
   auto pass_func = [](IRModule mod, PassContext ctx) {
     IRModuleNode* mod_ptr = mod.CopyOnWrite();
     auto* func_dict = mod_ptr->functions.CopyOnWrite();
-    IRModule device_mod = IRModule::Empty();
+    IRModule device_mod = IRModule();
 
     for (auto& kv : func_dict->data) {
       if (kv.second->IsInstance<PrimFuncNode>()) {