[Refactor][std::string --> String] IRModule is updated with String (#5523)
authorANSHUMAN TRIPATHY <anshuman.t@huawei.com>
Thu, 7 May 2020 22:11:25 +0000 (03:41 +0530)
committerGitHub <noreply@github.com>
Thu, 7 May 2020 22:11:25 +0000 (15:11 -0700)
* [std::string --> String] IRModule is updated with String

* [1] Packedfunction updated

* [2] Lint error fixed

* [3] Remove std::string variant

include/tvm/ir/module.h
src/ir/module.cc
src/printer/relay_text_printer.cc

index b0776de..d113860 100644 (file)
@@ -28,7 +28,7 @@
 #include <tvm/ir/expr.h>
 #include <tvm/ir/function.h>
 #include <tvm/ir/adt.h>
-
+#include <tvm/node/container.h>
 #include <string>
 #include <vector>
 #include <unordered_map>
@@ -131,21 +131,21 @@ class IRModuleNode : public Object {
    * \param name The variable name.
    * \returns true if contains, otherise false.
    */
-  TVM_DLL bool ContainGlobalVar(const std::string& name) const;
+  TVM_DLL bool ContainGlobalVar(const String& name) const;
 
   /*!
    * \brief Check if the global_type_var_map_ contains a global type variable.
    * \param name The variable name.
    * \returns true if contains, otherise false.
    */
-  TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const;
+  TVM_DLL bool ContainGlobalTypeVar(const String& name) const;
 
   /*!
    * \brief Lookup a global function by its variable.
    * \param str The unique string specifying the global variable.
    * \returns The global variable.
    */
-  TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;
+  TVM_DLL GlobalVar GetGlobalVar(const String& str) const;
 
   /*!
    * \brief Collect all global vars defined in this module.
@@ -158,7 +158,7 @@ class IRModuleNode : public Object {
    * \param str The unique string specifying the global variable.
    * \returns The global variable.
    */
-  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
+  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const;
 
   /*!
    * \brief Collect all global type vars defined in this module.
@@ -172,7 +172,7 @@ class IRModuleNode : public Object {
    * \param cons name of the constructor
    * \returns Constructor of ADT, error if not found
    */
-  TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const;
+  TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const;
 
   /*!
    * \brief Look up a global function by its variable.
@@ -186,7 +186,7 @@ class IRModuleNode : public Object {
    * \param name The name of the function.
    * \returns The function named by the argument.
    */
-  TVM_DLL BaseFunc Lookup(const std::string& name) const;
+  TVM_DLL BaseFunc Lookup(const String& name) const;
 
   /*!
    * \brief Look up a global type definition by its variable.
@@ -200,7 +200,7 @@ class IRModuleNode : public Object {
    * \param var The name of the global type definition.
    * \return The type definition.
    */
-  TVM_DLL TypeData LookupTypeDef(const std::string& var) const;
+  TVM_DLL TypeData LookupTypeDef(const String& var) const;
 
   /*!
    * \brief Look up a constructor by its tag.
@@ -225,18 +225,18 @@ class IRModuleNode : public Object {
    * relative it will be resovled against the current
    * working directory.
    */
-  TVM_DLL void Import(const std::string& path);
+  TVM_DLL void Import(const String& path);
 
   /*!
    * \brief Import Relay code from the file at path, relative to the standard library.
    * \param path The path of the Relay code to import.
    */
-  TVM_DLL void ImportFromStd(const std::string& path);
+  TVM_DLL void ImportFromStd(const String& path);
 
   /*!
    * \brief The set of imported files.
    */
-  TVM_DLL std::unordered_set<std::string> Imports() const;
+  TVM_DLL std::unordered_set<String> Imports() const;
 
   static constexpr const char* _type_key = "IRModule";
   static constexpr const bool _type_has_method_sequal_reduce = true;
@@ -265,7 +265,7 @@ class IRModuleNode : public Object {
   /*! \brief The files previously imported, required to ensure
       importing is idempotent for each module.
    */
-  std::unordered_set<std::string> import_set_;
+  std::unordered_set<String> import_set_;
   friend class IRModule;
 };
 
@@ -283,7 +283,7 @@ class IRModule : public ObjectRef {
    */
   TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
                             Map<GlobalTypeVar, TypeData> type_definitions = {},
-                            std::unordered_set<std::string> import_set = {});
+                            std::unordered_set<String> import_set = {});
   /*! \brief default constructor */
   IRModule() {}
   /*!
@@ -329,7 +329,7 @@ class IRModule : public ObjectRef {
    * \param source_path The path to the source file.
    * \return A Relay module.
    */
-  TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
+  TVM_DLL static IRModule FromText(const String& text, const String& source_path);
 
   /*! \brief Declare the container type. */
   using ContainerType = IRModuleNode;
@@ -346,7 +346,7 @@ class IRModule : public ObjectRef {
  *       Use AsText if you want to store the text.
  * \sa AsText.
  */
-TVM_DLL std::string PrettyPrint(const ObjectRef& node);
+TVM_DLL String PrettyPrint(const ObjectRef& node);
 
 /*!
  * \brief Render the node as a string in the text format.
@@ -362,8 +362,8 @@ TVM_DLL std::string PrettyPrint(const ObjectRef& node);
  * \sa PrettyPrint.
  * \return The text representation.
  */
-TVM_DLL std::string AsText(const ObjectRef& node,
+TVM_DLL String AsText(const ObjectRef& node,
                            bool show_meta_data = true,
-                           runtime::TypedPackedFunc<std::string(ObjectRef)> annotate = nullptr);
+                           runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
 }  // namespace tvm
 #endif  // TVM_IR_MODULE_H_
index 6262150..1be58f3 100644 (file)
@@ -40,7 +40,7 @@ namespace tvm {
 
 IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
                    tvm::Map<GlobalTypeVar, TypeData> type_definitions,
-                   std::unordered_set<std::string> import_set) {
+                   std::unordered_set<String> import_set) {
   auto n = make_object<IRModuleNode>();
   n->functions = std::move(functions);
   n->type_definitions = std::move(type_definitions);
@@ -111,15 +111,15 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
   reduce_temp();
 }
 
-bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
+bool IRModuleNode::ContainGlobalVar(const String& name) const {
   return global_var_map_.find(name) != global_var_map_.end();
 }
 
-bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {
+bool IRModuleNode::ContainGlobalTypeVar(const String& name) const {
   return global_type_var_map_.find(name) != global_type_var_map_.end();
 }
 
-GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
+GlobalVar IRModuleNode::GetGlobalVar(const String& name) const {
   auto it = global_var_map_.find(name);
   if (it == global_var_map_.end()) {
     std::ostringstream msg;
@@ -146,7 +146,7 @@ tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
   return tvm::Array<GlobalVar>(global_vars);
 }
 
-GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
+GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const {
   CHECK(global_type_var_map_.defined());
   auto it = global_type_var_map_.find(name);
   CHECK(it != global_type_var_map_.end())
@@ -154,7 +154,7 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
   return (*it).second;
 }
 
-Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const {
+Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const {
   TypeData typeDef = this->LookupTypeDef(adt);
   for (Constructor c : typeDef->constructors) {
     if (cons.compare(c->name_hint) == 0) {
@@ -315,7 +315,7 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
   return (*it).second;
 }
 
-BaseFunc IRModuleNode::Lookup(const std::string& name) const {
+BaseFunc IRModuleNode::Lookup(const String& name) const {
   GlobalVar id = this->GetGlobalVar(name);
   return this->Lookup(id);
 }
@@ -327,7 +327,7 @@ TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
   return (*it).second;
 }
 
-TypeData IRModuleNode::LookupTypeDef(const std::string& name) const {
+TypeData IRModuleNode::LookupTypeDef(const String& name) const {
   GlobalTypeVar id = this->GetGlobalTypeVar(name);
   return this->LookupTypeDef(id);
 }
@@ -379,7 +379,7 @@ IRModule IRModule::FromExpr(
   return mod;
 }
 
-void IRModuleNode::Import(const std::string& path) {
+void IRModuleNode::Import(const String& path) {
   if (this->import_set_.count(path) == 0) {
     this->import_set_.insert(path);
     DLOG(INFO) << "Importing: " << path;
@@ -392,18 +392,18 @@ void IRModuleNode::Import(const std::string& path) {
   }
 }
 
-void IRModuleNode::ImportFromStd(const std::string& path) {
+void IRModuleNode::ImportFromStd(const String& path) {
   auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
   CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
   std::string std_path = (*f)();
-  return this->Import(std_path + "/" + path);
+  this->Import(std_path + "/" + path.operator std::string());
 }
 
-std::unordered_set<std::string> IRModuleNode::Imports() const {
+std::unordered_set<String> IRModuleNode::Imports() const {
   return this->import_set_;
 }
 
-IRModule IRModule::FromText(const std::string& text, const std::string& source_path) {
+IRModule IRModule::FromText(const String& text, const String& source_path) {
   auto* f = tvm::runtime::Registry::Get("relay.fromtext");
   CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
   IRModule mod = (*f)(text, source_path);
@@ -467,7 +467,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Lookup")
 });
 
 TVM_REGISTER_GLOBAL("ir.Module_Lookup_str")
-.set_body_typed([](IRModule mod, std::string var) {
+.set_body_typed([](IRModule mod, String var) {
   return mod->Lookup(var);
 });
 
@@ -477,7 +477,7 @@ TVM_REGISTER_GLOBAL("ir.Module_LookupDef")
 });
 
 TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str")
-.set_body_typed([](IRModule mod, std::string var) {
+.set_body_typed([](IRModule mod, String var) {
   return mod->LookupTypeDef(var);
 });
 
@@ -499,12 +499,12 @@ TVM_REGISTER_GLOBAL("ir.Module_Update")
 });
 
 TVM_REGISTER_GLOBAL("ir.Module_Import")
-.set_body_typed([](IRModule mod, std::string path) {
+.set_body_typed([](IRModule mod, String path) {
   mod->Import(path);
 });
 
 TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd")
-.set_body_typed([](IRModule mod, std::string path) {
+.set_body_typed([](IRModule mod, String path) {
   mod->ImportFromStd(path);
 });;
 
index bda997a..2e675c8 100644 (file)
@@ -918,22 +918,28 @@ static const char* kSemVer = "v0.0.4";
 //    - Implements AsText
 // - relay_text_printer.cc (specific printing logics for relay)
 // - tir_text_printer.cc (specific printing logics for TIR)
-std::string PrettyPrint(const ObjectRef& node) {
+String PrettyPrint(const ObjectRef& node) {
   Doc doc;
   doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
   return doc.str();
 }
 
-std::string AsText(const ObjectRef& node,
+String AsText(const ObjectRef& node,
                    bool show_meta_data,
-                   runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) {
+                   runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
   Doc doc;
   doc << kSemVer << Doc::NewLine();
-  doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node);
+  runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
+  if (annotate != nullptr) {
+    ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
+      [&annotate](const ObjectRef& expr) -> std::string {
+        return annotate(expr);
+      });
+  }
+  doc << relay::RelayTextPrinter(show_meta_data, ftyped).PrintFinal(node);
   return doc.str();
 }
 
-
 TVM_REGISTER_GLOBAL("ir.PrettyPrint")
 .set_body_typed(PrettyPrint);