fix (#3417)
author雾雨魔理沙 <lolisa@marisa.moe>
Mon, 24 Jun 2019 04:40:14 +0000 (21:40 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 24 Jun 2019 04:40:14 +0000 (21:40 -0700)
include/tvm/relay/module.h
src/relay/backend/vm/lambda_lift.cc
src/relay/ir/module.cc

index 3966a62..638f759 100644 (file)
@@ -123,42 +123,42 @@ class ModuleNode : public RelayNode {
    * \param str The unique string specifying the global variable.
    * \returns The global variable.
    */
-  TVM_DLL GlobalVar GetGlobalVar(const std::string& str);
+  TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;
 
   /*!
    * \brief Look up a global function by its name.
    * \param str The unique string specifying the global variable.
    * \returns The global variable.
    */
-  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str);
+  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
 
   /*!
    * \brief Lookup a global function by its variable.
    * \param var The global var to lookup.
    * \returns The function named by the variable argument.
    */
-  TVM_DLL Function Lookup(const GlobalVar& var);
+  TVM_DLL Function Lookup(const GlobalVar& var) const;
 
   /*!
    * \brief Lookup a global function by its string name
    * \param name The name of the function.
    * \returns The function named by the argument.
    */
-  TVM_DLL Function Lookup(const std::string& name);
+  TVM_DLL Function Lookup(const std::string& name) const;
 
   /*!
    * \brief Lookup a global type definition by its variable.
    * \param var The var of the global type definition.
    * \return The type definition.
    */
-  TVM_DLL TypeData LookupDef(const GlobalTypeVar& var);
+  TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const;
 
   /*!
    * \brief Lookup a global type definition by its name.
    * \param var The name of the global type definition.
    * \return The type definition.
    */
-  TVM_DLL TypeData LookupDef(const std::string& var);
+  TVM_DLL TypeData LookupDef(const std::string& var) const;
 
   /*!
    * \brief Update the functions inside this environment by
index a55a927..668c024 100644 (file)
@@ -112,7 +112,7 @@ struct LambdaLifter : ExprMutator {
     CHECK(lifted_func.defined());
 
     auto name = GenerateName(lifted_func);
-    auto global = module_->GetGlobalVar(name);
+    auto global = GlobalVarNode::make(name);
 
     // Add the lifted function to the module.
     module_->Add(global, lifted_func);
index 6b5fee8..58f614a 100644 (file)
@@ -57,15 +57,11 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
   return Module(n);
 }
 
-GlobalVar ModuleNode::GetGlobalVar(const std::string& name) {
+GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
   auto it = global_var_map_.find(name);
-  if (it == global_var_map_.end()) {
-    auto gvar = GlobalVarNode::make(name);
-    global_var_map_.Set(name, gvar);
-    return gvar;
-  } else {
-    return (*it).second;
-  }
+  CHECK(it != global_var_map_.end())
+    << "Cannot find global var " << name << " in the Module";
+  return (*it).second;
 }
 
 void ModuleNode::AddUnchecked(const GlobalVar& var,
@@ -84,7 +80,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
   global_var_map_.Set(var->name_hint, var);
 }
 
-GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) {
+GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
   auto it = global_type_var_map_.find(name);
   CHECK(it != global_type_var_map_.end())
     << "Cannot find global type var " << name << " in the Module";
@@ -137,26 +133,26 @@ void ModuleNode::Remove(const GlobalVar& var) {
   gvar_node->data.erase(var->name_hint);
 }
 
-Function ModuleNode::Lookup(const GlobalVar& var) {
+Function ModuleNode::Lookup(const GlobalVar& var) const {
   auto it = functions.find(var);
   CHECK(it != functions.end())
       << "There is no definition of " << var->name_hint;
   return (*it).second;
 }
 
-Function ModuleNode::Lookup(const std::string& name) {
+Function ModuleNode::Lookup(const std::string& name) const {
   GlobalVar id = this->GetGlobalVar(name);
   return this->Lookup(id);
 }
 
-TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) {
+TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
   auto it = type_definitions.find(var);
   CHECK(it != type_definitions.end())
     << "There is no definition of " << var->var->name_hint;
   return (*it).second;
 }
 
-TypeData ModuleNode::LookupDef(const std::string& name) {
+TypeData ModuleNode::LookupDef(const std::string& name) const {
   GlobalTypeVar id = this->GetGlobalTypeVar(name);
   return this->LookupDef(id);
 }