[Refactor][std::string --> String] IR is updated with String (#5547)
authorANSHUMAN TRIPATHY <anshuman.t@huawei.com>
Mon, 11 May 2020 19:08:04 +0000 (00:38 +0530)
committerGitHub <noreply@github.com>
Mon, 11 May 2020 19:08:04 +0000 (12:08 -0700)
* [std::string --> String] GlobalTypeVar is updated with String

* [std::string --> String] GlobalVar is updated with String

* [std::string --> String][IR] ADT is updated with String

* [std::string --> String][IR] OP is updated with String

* [std::string --> String][IR] Attrs is updated with String input

* [std::string --> String][IR] GlobalVar is updated with String

* [std::string --> String][Test] Pyconverter is updated with String change

17 files changed:
include/tvm/ir/adt.h
include/tvm/ir/env_func.h
include/tvm/ir/expr.h
include/tvm/ir/op.h
include/tvm/ir/transform.h
include/tvm/ir/type.h
include/tvm/runtime/container.h
python/tvm/ir/json_compact.py
python/tvm/relay/testing/py_converter.py
src/ir/adt.cc
src/ir/env_func.cc
src/ir/expr.cc
src/ir/function.cc
src/ir/op.cc
src/ir/transform.cc
src/ir/type.cc
src/printer/relay_text_printer.cc

index 9d45dc1..9b45c66 100644 (file)
@@ -91,7 +91,7 @@ class Constructor : public RelayExpr {
    * \param inputs The input types.
    * \param belong_to The data type var the constructor will construct.
    */
-  TVM_DLL Constructor(std::string name_hint, Array<Type> inputs, GlobalTypeVar belong_to);
+  TVM_DLL Constructor(String name_hint, Array<Type> inputs, GlobalTypeVar belong_to);
 
   TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
 };
index 320d6e3..2f80367 100644 (file)
@@ -92,7 +92,7 @@ class EnvFunc : public ObjectRef {
    * \return The created global function.
    * \note The function can be unique
    */
-  TVM_DLL static EnvFunc Get(const std::string& name);
+  TVM_DLL static EnvFunc Get(const String& name);
   /*! \brief specify container node */
   using ContainerType = EnvFuncNode;
 };
index 717ffb1..6797f16 100644 (file)
@@ -188,7 +188,7 @@ class GlobalVar;
 class GlobalVarNode : public RelayExprNode {
  public:
   /*! \brief The name of the variable, this only acts as a hint. */
-  std::string name_hint;
+  String name_hint;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("name_hint", &name_hint);
@@ -216,7 +216,7 @@ class GlobalVarNode : public RelayExprNode {
  */
 class GlobalVar : public RelayExpr {
  public:
-  TVM_DLL explicit GlobalVar(std::string name_hint);
+  TVM_DLL explicit GlobalVar(String name_hint);
 
   TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
 };
index 7fafb5a..aeda4fa 100644 (file)
@@ -185,7 +185,7 @@ class Op : public RelayExpr {
    * \param op_name Name of the operator.
    * \return Pointer to a Op, valid throughout program lifetime.
    */
-  TVM_DLL static const Op& Get(const std::string& op_name);
+  TVM_DLL static const Op& Get(const String& op_name);
 
   /*! \brief specify container node */
   using ContainerType = OpNode;
@@ -196,13 +196,13 @@ class Op : public RelayExpr {
    * \param key The attribute key
    * \return reference to GenericOpMap
    */
-  TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
+  TVM_DLL static const GenericOpMap& GetGenericAttr(const String& key);
   /*!
    * \brief Checks if the key is present in the registry
    * \param key The attribute key
    * \return bool True if the key is present
    */
-  TVM_DLL static bool HasGenericAttr(const std::string& key);
+  TVM_DLL static bool HasGenericAttr(const String& key);
 };
 
 /*!
@@ -303,7 +303,8 @@ class OpRegistry {
   // return internal pointer to op.
   inline OpNode* get();
   // update the attribute OpMap
-  TVM_DLL void UpdateAttr(const std::string& key, runtime::TVMRetValue value, int plevel);
+
+  TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
 };
 
 /*!
index 558d2da..a825b95 100644 (file)
@@ -224,7 +224,7 @@ class PassInfo : public ObjectRef {
    * \param name Name of the pass.
    * \param required  The passes that are required to perform the current pass.
    */
-  TVM_DLL PassInfo(int opt_level, std::string name, Array<runtime::String> required);
+  TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required);
 
   TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
 };
@@ -327,7 +327,7 @@ class Sequential : public Pass {
    *        This allows users to only provide a list of passes and execute them
    *        under a given context.
    */
-  TVM_DLL Sequential(Array<Pass> passes, std::string name = "sequential");
+  TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");
 
   Sequential() = default;
   explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
@@ -348,7 +348,7 @@ class Sequential : public Pass {
  */
 TVM_DLL Pass
 CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
-                 int opt_level, const std::string& name, const Array<runtime::String>& required);
+                 int opt_level, const String& name, const Array<runtime::String>& required);
 
 /*!
  * \brief A special trace pass that prints the header and IR to LOG(INFO).
@@ -356,7 +356,7 @@ CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>
  * \param show_meta_data Whether should we show meta data.
  * \return The pass.
  */
-TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);
+TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
 
 }  // namespace transform
 }  // namespace tvm
index ed64841..65b454f 100644 (file)
@@ -267,7 +267,7 @@ class GlobalTypeVarNode : public TypeNode {
    *  this only acts as a hint to the user,
    *  and is not used for equality.
    */
-  std::string name_hint;
+  String name_hint;
   /*! \brief The kind of type parameter */
   TypeKind kind;
 
@@ -301,7 +301,7 @@ class GlobalTypeVar : public Type {
    * \param name_hint The name of the type var.
    * \param kind The kind of the type var.
    */
-  TVM_DLL GlobalTypeVar(std::string name_hint, TypeKind kind);
+  TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind);
 
   TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
 };
index 49c005e..e2f2453 100644 (file)
@@ -564,6 +564,15 @@ inline String String::operator=(std::string other) {
   return Downcast<String>(*this);
 }
 
+inline String operator+(const std::string lhs, const String& rhs) {
+  return lhs + rhs.operator std::string();
+}
+
+inline std::ostream& operator<<(std::ostream& out, const String& input) {
+  out.write(input.data(), input.size());
+  return out;
+}
+
 inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
   if (lhs == rhs && lhs_count == rhs_count) return 0;
 
index fcea9d8..a3ff499 100644 (file)
@@ -111,7 +111,7 @@ def create_updater_06_to_07():
         "EnvFunc": _update_global_key,
         "relay.Op": _update_global_key,
         "relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
-        "relay.GlobalTypeVar": _ftype_var,
+        "relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")],
         "relay.Type": _rename("Type"),
         "relay.TupleType": _rename("TupleType"),
         "relay.TypeConstraint": _rename("TypeConstraint"),
@@ -122,7 +122,7 @@ def create_updater_06_to_07():
         "relay.Module": _rename("IRModule"),
         "relay.SourceName": _rename("SourceName"),
         "relay.Span": _rename("Span"),
-        "relay.GlobalVar": _rename("GlobalVar"),
+        "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint")],
         "relay.Pass": _rename("transform.Pass"),
         "relay.PassInfo": _rename("transform.PassInfo"),
         "relay.PassContext": _rename("transform.PassContext"),
index 61a04ec..89c3393 100644 (file)
@@ -190,7 +190,7 @@ class PythonConverter(ExprFunctor):
         if name_var is None:
             func_name = self.generate_function_name('_anon_func')
         if isinstance(name_var, GlobalVar):
-            func_name = name_var.name_hint
+            func_name = str(name_var.name_hint)
         if isinstance(name_var, Var):
             func_name = self.get_var_name(name_var)
 
@@ -411,7 +411,7 @@ class PythonConverter(ExprFunctor):
     def visit_global_var(self, gvar: Expr):
         # we don't need to add numbers to global var names because
         # the *names* are checked for uniqueness in the mod
-        return (Name(gvar.name_hint, Load()), [])
+        return (Name(str(gvar.name_hint), Load()), [])
 
 
     def visit_let(self, letexp: Expr):
index 957905d..f0ce859 100644 (file)
@@ -26,7 +26,7 @@
 
 namespace tvm {
 
-Constructor::Constructor(std::string name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
+Constructor::Constructor(String name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
   ObjectPtr<ConstructorNode> n = make_object<ConstructorNode>();
   n->name_hint = std::move(name_hint);
   n->inputs = std::move(inputs);
@@ -37,7 +37,7 @@ Constructor::Constructor(std::string name_hint, tvm::Array<Type> inputs, GlobalT
 TVM_REGISTER_NODE_TYPE(ConstructorNode);
 
 TVM_REGISTER_GLOBAL("ir.Constructor")
-    .set_body_typed([](std::string name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
+    .set_body_typed([](String name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
       return Constructor(name_hint, inputs, belong_to);
     });
 
index 7deff90..7b0d6e6 100644 (file)
@@ -45,7 +45,7 @@ ObjectPtr<Object> CreateEnvNode(const std::string& name) {
   return n;
 }
 
-EnvFunc EnvFunc::Get(const std::string& name) { return EnvFunc(CreateEnvNode(name)); }
+EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); }
 
 TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get);
 
index 000305b..8b2656b 100644 (file)
@@ -137,7 +137,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
     });
 
-GlobalVar::GlobalVar(std::string name_hint) {
+GlobalVar::GlobalVar(String name_hint) {
   ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
   n->name_hint = std::move(name_hint);
   data_ = std::move(n);
@@ -145,7 +145,7 @@ GlobalVar::GlobalVar(std::string name_hint) {
 
 TVM_REGISTER_NODE_TYPE(GlobalVarNode);
 
-TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](std::string name) {
+TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) {
   return GlobalVar(name);
 });
 
index 57d62b4..c0cda70 100644 (file)
@@ -38,7 +38,7 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { retu
 TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; });
 
 TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
-    .set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
+    .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc {
       if (func->IsInstance<tir::PrimFuncNode>()) {
         return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
       } else if (func->IsInstance<relay::FunctionNode>()) {
index 8f58768..3a6bcbc 100644 (file)
@@ -61,7 +61,7 @@ struct OpManager {
 };
 
 // find operator by name
-const Op& Op::Get(const std::string& name) {
+const Op& Op::Get(const String& name) {
   const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
   CHECK(reg != nullptr) << "Operator " << name << " is not registered";
   return reg->op();
@@ -75,7 +75,7 @@ OpRegistry::OpRegistry() {
 }
 
 // Get attribute map by key
-const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
+const GenericOpMap& Op::GetGenericAttr(const String& key) {
   OpManager* mgr = OpManager::Global();
   std::lock_guard<std::mutex> lock(mgr->mutex);
   auto it = mgr->attr.find(key);
@@ -86,7 +86,7 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
 }
 
 // Check if a key is present in the registry.
-bool Op::HasGenericAttr(const std::string& key) {
+bool Op::HasGenericAttr(const String& key) {
   OpManager* mgr = OpManager::Global();
   std::lock_guard<std::mutex> lock(mgr->mutex);
   auto it = mgr->attr.find(key);
@@ -110,7 +110,7 @@ void OpRegistry::reset_attr(const std::string& key) {
   }
 }
 
-void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, int plevel) {
+void OpRegistry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
   OpManager* mgr = OpManager::Global();
   std::lock_guard<std::mutex> lock(mgr->mutex);
   std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
@@ -141,7 +141,7 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() {
   return ret;
 });
 
-TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](std::string name) -> Op {
+TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op {
   return Op::Get(name);
 });
 
index d7d9b06..59e0c1c 100644 (file)
@@ -201,7 +201,7 @@ class SequentialNode : public PassNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
 };
 
-PassInfo::PassInfo(int opt_level, std::string name, tvm::Array<runtime::String> required) {
+PassInfo::PassInfo(int opt_level, String name, tvm::Array<runtime::String> required) {
   auto pass_info = make_object<PassInfoNode>();
   pass_info->opt_level = opt_level;
   pass_info->name = std::move(name);
@@ -238,7 +238,7 @@ Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
   data_ = std::move(n);
 }
 
-Sequential::Sequential(tvm::Array<Pass> passes, std::string name) {
+Sequential::Sequential(tvm::Array<Pass> passes, String name) {
   auto n = make_object<SequentialNode>();
   n->passes = std::move(passes);
   PassInfo pass_info = PassInfo(2, std::move(name), {});
@@ -282,10 +282,10 @@ bool SequentialNode::PassEnabled(const PassInfo& info) const {
   return ctx->opt_level >= info->opt_level;
 }
 
-Pass GetPass(const std::string& pass_name) {
+Pass GetPass(const String& pass_name) {
   using tvm::runtime::Registry;
   const runtime::PackedFunc* f = nullptr;
-  if (pass_name.find("transform.") != std::string::npos) {
+  if (pass_name.operator std::string().find("transform.") != std::string::npos) {
     f = Registry::Get(pass_name);
   } else if ((f = Registry::Get("transform." + pass_name))) {
     // pass
@@ -313,7 +313,7 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c
 }
 
 Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
-                      int opt_level, const std::string& name,
+                      int opt_level, const String& name,
                       const tvm::Array<runtime::String>& required) {
   PassInfo pass_info = PassInfo(opt_level, name, required);
   return ModulePass(pass_func, pass_info);
@@ -322,7 +322,7 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassCont
 TVM_REGISTER_NODE_TYPE(PassInfoNode);
 
 TVM_REGISTER_GLOBAL("transform.PassInfo")
-    .set_body_typed([](int opt_level, std::string name, tvm::Array<runtime::String> required) {
+    .set_body_typed([](int opt_level, String name, tvm::Array<runtime::String> required) {
       return PassInfo(opt_level, name, required);
     });
 
@@ -439,7 +439,7 @@ TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::In
 
 TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope);
 
-Pass PrintIR(std::string header, bool show_meta_data) {
+Pass PrintIR(String header, bool show_meta_data) {
   auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) {
     LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data);
     return mod;
index 212a6e5..38a6ec3 100644 (file)
@@ -81,7 +81,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")";
     });
 
-GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {
+GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind) {
   ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
   n->name_hint = std::move(name);
   n->kind = std::move(kind);
@@ -90,7 +90,7 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {
 
 TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
 
-TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](std::string name, int kind) {
+TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](String name, int kind) {
   return GlobalTypeVar(name, static_cast<TypeKind>(kind));
 });
 
index 3c545ef..5166a48 100644 (file)
@@ -446,7 +446,9 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
   return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
 }
 
-Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text('@' + op->name_hint); }
+Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
+  return Doc::Text('@' + op->name_hint.operator std::string());
+}
 
 Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }