#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/adt.h>
-
+#include <tvm/node/container.h>
#include <string>
#include <vector>
#include <unordered_map>
* \param name The variable name.
* \returns true if contains, otherise false.
*/
- TVM_DLL bool ContainGlobalVar(const std::string& name) const;
+ TVM_DLL bool ContainGlobalVar(const String& name) const;
/*!
* \brief Check if the global_type_var_map_ contains a global type variable.
* \param name The variable name.
* \returns true if contains, otherise false.
*/
- TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const;
+ TVM_DLL bool ContainGlobalTypeVar(const String& name) const;
/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
- TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;
+ TVM_DLL GlobalVar GetGlobalVar(const String& str) const;
/*!
* \brief Collect all global vars defined in this module.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
- TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
+ TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const;
/*!
* \brief Collect all global type vars defined in this module.
* \param cons name of the constructor
* \returns Constructor of ADT, error if not found
*/
- TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const;
+ TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const;
/*!
* \brief Look up a global function by its variable.
* \param name The name of the function.
* \returns The function named by the argument.
*/
- TVM_DLL BaseFunc Lookup(const std::string& name) const;
+ TVM_DLL BaseFunc Lookup(const String& name) const;
/*!
* \brief Look up a global type definition by its variable.
* \param var The name of the global type definition.
* \return The type definition.
*/
- TVM_DLL TypeData LookupTypeDef(const std::string& var) const;
+ TVM_DLL TypeData LookupTypeDef(const String& var) const;
/*!
* \brief Look up a constructor by its tag.
* relative it will be resovled against the current
* working directory.
*/
- TVM_DLL void Import(const std::string& path);
+ TVM_DLL void Import(const String& path);
/*!
* \brief Import Relay code from the file at path, relative to the standard library.
* \param path The path of the Relay code to import.
*/
- TVM_DLL void ImportFromStd(const std::string& path);
+ TVM_DLL void ImportFromStd(const String& path);
/*!
* \brief The set of imported files.
*/
- TVM_DLL std::unordered_set<std::string> Imports() const;
+ TVM_DLL std::unordered_set<String> Imports() const;
static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
/*! \brief The files previously imported, required to ensure
importing is idempotent for each module.
*/
- std::unordered_set<std::string> import_set_;
+ std::unordered_set<String> import_set_;
friend class IRModule;
};
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
- std::unordered_set<std::string> import_set = {});
+ std::unordered_set<String> import_set = {});
/*! \brief default constructor */
IRModule() {}
/*!
* \param source_path The path to the source file.
* \return A Relay module.
*/
- TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
+ TVM_DLL static IRModule FromText(const String& text, const String& source_path);
/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;
* Use AsText if you want to store the text.
* \sa AsText.
*/
-TVM_DLL std::string PrettyPrint(const ObjectRef& node);
+TVM_DLL String PrettyPrint(const ObjectRef& node);
/*!
* \brief Render the node as a string in the text format.
* \sa PrettyPrint.
* \return The text representation.
*/
-TVM_DLL std::string AsText(const ObjectRef& node,
+TVM_DLL String AsText(const ObjectRef& node,
bool show_meta_data = true,
- runtime::TypedPackedFunc<std::string(ObjectRef)> annotate = nullptr);
+ runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
} // namespace tvm
#endif // TVM_IR_MODULE_H_
IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions,
- std::unordered_set<std::string> import_set) {
+ std::unordered_set<String> import_set) {
auto n = make_object<IRModuleNode>();
n->functions = std::move(functions);
n->type_definitions = std::move(type_definitions);
reduce_temp();
}
-bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
+bool IRModuleNode::ContainGlobalVar(const String& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}
-bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {
+bool IRModuleNode::ContainGlobalTypeVar(const String& name) const {
return global_type_var_map_.find(name) != global_type_var_map_.end();
}
-GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
+GlobalVar IRModuleNode::GetGlobalVar(const String& name) const {
auto it = global_var_map_.find(name);
if (it == global_var_map_.end()) {
std::ostringstream msg;
return tvm::Array<GlobalVar>(global_vars);
}
-GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
+GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const {
CHECK(global_type_var_map_.defined());
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
return (*it).second;
}
-Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const {
+Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const {
TypeData typeDef = this->LookupTypeDef(adt);
for (Constructor c : typeDef->constructors) {
if (cons.compare(c->name_hint) == 0) {
return (*it).second;
}
-BaseFunc IRModuleNode::Lookup(const std::string& name) const {
+BaseFunc IRModuleNode::Lookup(const String& name) const {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}
return (*it).second;
}
-TypeData IRModuleNode::LookupTypeDef(const std::string& name) const {
+TypeData IRModuleNode::LookupTypeDef(const String& name) const {
GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupTypeDef(id);
}
return mod;
}
-void IRModuleNode::Import(const std::string& path) {
+void IRModuleNode::Import(const String& path) {
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
}
}
-void IRModuleNode::ImportFromStd(const std::string& path) {
+void IRModuleNode::ImportFromStd(const String& path) {
auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
std::string std_path = (*f)();
- return this->Import(std_path + "/" + path);
+ this->Import(std_path + "/" + path.operator std::string());
}
-std::unordered_set<std::string> IRModuleNode::Imports() const {
+std::unordered_set<String> IRModuleNode::Imports() const {
return this->import_set_;
}
-IRModule IRModule::FromText(const std::string& text, const std::string& source_path) {
+IRModule IRModule::FromText(const String& text, const String& source_path) {
auto* f = tvm::runtime::Registry::Get("relay.fromtext");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
IRModule mod = (*f)(text, source_path);
});
TVM_REGISTER_GLOBAL("ir.Module_Lookup_str")
-.set_body_typed([](IRModule mod, std::string var) {
+.set_body_typed([](IRModule mod, String var) {
return mod->Lookup(var);
});
});
TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str")
-.set_body_typed([](IRModule mod, std::string var) {
+.set_body_typed([](IRModule mod, String var) {
return mod->LookupTypeDef(var);
});
});
TVM_REGISTER_GLOBAL("ir.Module_Import")
-.set_body_typed([](IRModule mod, std::string path) {
+.set_body_typed([](IRModule mod, String path) {
mod->Import(path);
});
TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd")
-.set_body_typed([](IRModule mod, std::string path) {
+.set_body_typed([](IRModule mod, String path) {
mod->ImportFromStd(path);
});;
// - Implements AsText
// - relay_text_printer.cc (specific printing logics for relay)
// - tir_text_printer.cc (specific printing logics for TIR)
-std::string PrettyPrint(const ObjectRef& node) {
+String PrettyPrint(const ObjectRef& node) {
Doc doc;
doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
return doc.str();
}
-std::string AsText(const ObjectRef& node,
+String AsText(const ObjectRef& node,
bool show_meta_data,
- runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) {
+ runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
Doc doc;
doc << kSemVer << Doc::NewLine();
- doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node);
+ runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
+ if (annotate != nullptr) {
+ ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
+ [&annotate](const ObjectRef& expr) -> std::string {
+ return annotate(expr);
+ });
+ }
+ doc << relay::RelayTextPrinter(show_meta_data, ftyped).PrintFinal(node);
return doc.str();
}
-
TVM_REGISTER_GLOBAL("ir.PrettyPrint")
.set_body_typed(PrettyPrint);