* \param inputs The input types.
* \param belong_to The data type var the constructor will construct.
*/
- TVM_DLL Constructor(std::string name_hint, Array<Type> inputs, GlobalTypeVar belong_to);
+ TVM_DLL Constructor(String name_hint, Array<Type> inputs, GlobalTypeVar belong_to);
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
};
* \return The created global function.
* \note The function can be unique
*/
- TVM_DLL static EnvFunc Get(const std::string& name);
+ TVM_DLL static EnvFunc Get(const String& name);
/*! \brief specify container node */
using ContainerType = EnvFuncNode;
};
class GlobalVarNode : public RelayExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint. */
- std::string name_hint;
+ String name_hint;
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
*/
class GlobalVar : public RelayExpr {
public:
- TVM_DLL explicit GlobalVar(std::string name_hint);
+ TVM_DLL explicit GlobalVar(String name_hint);
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};
* \param op_name Name of the operator.
* \return Pointer to a Op, valid throughout program lifetime.
*/
- TVM_DLL static const Op& Get(const std::string& op_name);
+ TVM_DLL static const Op& Get(const String& op_name);
/*! \brief specify container node */
using ContainerType = OpNode;
* \param key The attribute key
* \return reference to GenericOpMap
*/
- TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
+ TVM_DLL static const GenericOpMap& GetGenericAttr(const String& key);
/*!
* \brief Checks if the key is present in the registry
* \param key The attribute key
* \return bool True if the key is present
*/
- TVM_DLL static bool HasGenericAttr(const std::string& key);
+ TVM_DLL static bool HasGenericAttr(const String& key);
};
/*!
// return internal pointer to op.
inline OpNode* get();
// update the attribute OpMap
- TVM_DLL void UpdateAttr(const std::string& key, runtime::TVMRetValue value, int plevel);
+
+ TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
};
/*!
* \param name Name of the pass.
* \param required The passes that are required to perform the current pass.
*/
- TVM_DLL PassInfo(int opt_level, std::string name, Array<runtime::String> required);
+ TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required);
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
- TVM_DLL Sequential(Array<Pass> passes, std::string name = "sequential");
+ TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");
Sequential() = default;
explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
*/
TVM_DLL Pass
CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
- int opt_level, const std::string& name, const Array<runtime::String>& required);
+ int opt_level, const String& name, const Array<runtime::String>& required);
/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param show_meta_data Whether should we show meta data.
* \return The pass.
*/
-TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);
+TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
} // namespace transform
} // namespace tvm
* this only acts as a hint to the user,
* and is not used for equality.
*/
- std::string name_hint;
+ String name_hint;
/*! \brief The kind of type parameter */
TypeKind kind;
* \param name_hint The name of the type var.
* \param kind The kind of the type var.
*/
- TVM_DLL GlobalTypeVar(std::string name_hint, TypeKind kind);
+ TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
};
return Downcast<String>(*this);
}
+inline String operator+(const std::string lhs, const String& rhs) {
+ return lhs + rhs.operator std::string();
+}
+
+inline std::ostream& operator<<(std::ostream& out, const String& input) {
+ out.write(input.data(), input.size());
+ return out;
+}
+
inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
if (lhs == rhs && lhs_count == rhs_count) return 0;
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
- "relay.GlobalTypeVar": _ftype_var,
+ "relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
"relay.TypeConstraint": _rename("TypeConstraint"),
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
- "relay.GlobalVar": _rename("GlobalVar"),
+ "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint")],
"relay.Pass": _rename("transform.Pass"),
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
if name_var is None:
func_name = self.generate_function_name('_anon_func')
if isinstance(name_var, GlobalVar):
- func_name = name_var.name_hint
+ func_name = str(name_var.name_hint)
if isinstance(name_var, Var):
func_name = self.get_var_name(name_var)
def visit_global_var(self, gvar: Expr):
# we don't need to add numbers to global var names because
# the *names* are checked for uniqueness in the mod
- return (Name(gvar.name_hint, Load()), [])
+ return (Name(str(gvar.name_hint), Load()), [])
def visit_let(self, letexp: Expr):
namespace tvm {
-Constructor::Constructor(std::string name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
+Constructor::Constructor(String name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
ObjectPtr<ConstructorNode> n = make_object<ConstructorNode>();
n->name_hint = std::move(name_hint);
n->inputs = std::move(inputs);
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_GLOBAL("ir.Constructor")
- .set_body_typed([](std::string name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
+ .set_body_typed([](String name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
return Constructor(name_hint, inputs, belong_to);
});
return n;
}
-EnvFunc EnvFunc::Get(const std::string& name) { return EnvFunc(CreateEnvNode(name)); }
+EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); }
TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get);
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
-GlobalVar::GlobalVar(std::string name_hint) {
+GlobalVar::GlobalVar(String name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
data_ = std::move(n);
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
-TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](std::string name) {
+TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) {
return GlobalVar(name);
});
TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; });
TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
- .set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
+ .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc {
if (func->IsInstance<tir::PrimFuncNode>()) {
return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
} else if (func->IsInstance<relay::FunctionNode>()) {
};
// find operator by name
-const Op& Op::Get(const std::string& name) {
+const Op& Op::Get(const String& name) {
const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
CHECK(reg != nullptr) << "Operator " << name << " is not registered";
return reg->op();
}
// Get attribute map by key
-const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
+const GenericOpMap& Op::GetGenericAttr(const String& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
}
// Check if a key is present in the registry.
-bool Op::HasGenericAttr(const std::string& key) {
+bool Op::HasGenericAttr(const String& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
}
}
-void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, int plevel) {
+void OpRegistry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
return ret;
});
-TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](std::string name) -> Op {
+TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op {
return Op::Get(name);
});
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};
-PassInfo::PassInfo(int opt_level, std::string name, tvm::Array<runtime::String> required) {
+PassInfo::PassInfo(int opt_level, String name, tvm::Array<runtime::String> required) {
auto pass_info = make_object<PassInfoNode>();
pass_info->opt_level = opt_level;
pass_info->name = std::move(name);
data_ = std::move(n);
}
-Sequential::Sequential(tvm::Array<Pass> passes, std::string name) {
+Sequential::Sequential(tvm::Array<Pass> passes, String name) {
auto n = make_object<SequentialNode>();
n->passes = std::move(passes);
PassInfo pass_info = PassInfo(2, std::move(name), {});
return ctx->opt_level >= info->opt_level;
}
-Pass GetPass(const std::string& pass_name) {
+Pass GetPass(const String& pass_name) {
using tvm::runtime::Registry;
const runtime::PackedFunc* f = nullptr;
- if (pass_name.find("transform.") != std::string::npos) {
+ if (pass_name.operator std::string().find("transform.") != std::string::npos) {
f = Registry::Get(pass_name);
} else if ((f = Registry::Get("transform." + pass_name))) {
// pass
}
Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
- int opt_level, const std::string& name,
+ int opt_level, const String& name,
const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return ModulePass(pass_func, pass_info);
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("transform.PassInfo")
- .set_body_typed([](int opt_level, std::string name, tvm::Array<runtime::String> required) {
+ .set_body_typed([](int opt_level, String name, tvm::Array<runtime::String> required) {
return PassInfo(opt_level, name, required);
});
TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope);
-Pass PrintIR(std::string header, bool show_meta_data) {
+Pass PrintIR(String header, bool show_meta_data) {
auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) {
LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data);
return mod;
p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")";
});
-GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {
+GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind) {
ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
-TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](std::string name, int kind) {
+TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](String name, int kind) {
return GlobalTypeVar(name, static_cast<TypeKind>(kind));
});
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}
-Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text('@' + op->name_hint); }
+Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
+ return Doc::Text('@' + op->name_hint.operator std::string());
+}
Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }