using namespace tvm::relay::transform;
/*!
- * \brief Bind params to function by using name
- * \param func Relay function
- * \param params params dict
- * \return relay::Function
- */
-relay::Function BindParamsByName(relay::Function func,
- const std::unordered_map<std::string, runtime::NDArray>& params) {
- std::unordered_map<std::string, relay::Var> name_dict;
- std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
- for (auto arg : func->params) {
- const auto& name = arg->name_hint();
- if (name_dict.count(name)) {
- repeat_var.insert(arg);
- } else {
- name_dict[name] = arg;
- }
- }
-
- std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
- for (auto& kv : params) {
- if (name_dict.count(kv.first) == 0) {
- continue;
- }
- auto arg = name_dict.at(kv.first);
- if (repeat_var.count(arg)) {
- LOG(FATAL) << "Multiple args in the function have name " << kv.first;
- }
- bind_dict[arg] = ConstantNode::make(kv.second);
- }
- Expr bound_expr = relay::Bind(func, bind_dict);
- Function ret = Downcast<Function>(bound_expr);
- CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
- << "\n";
- return ret;
-}
-
-/*!
* \brief Output of building module
*
*/
for (const auto& kv : params) {
params_[kv.first] = kv.second->data;
}
- *rv = BindParamsByName(args[0], params_);
+ *rv = relay::backend::BindParamsByName(args[0], params_);
});
} // namespace backend
#include <dmlc/json.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
+#include <tvm/relay/transform.h>
#include <tvm/driver/driver_api.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/ir_pass.h>
#include <typeinfo>
#include <string>
+#include <unordered_map>
+#include <unordered_set>
namespace tvm {
namespace relay {
return os.str();
}
+/*!
+ * \brief Bind params to function by using name
+ * \param func Relay function
+ * \param params params dict
+ * \return relay::Function
+ */
+inline relay::Function
+BindParamsByName(relay::Function func,
+ const std::unordered_map<std::string, runtime::NDArray>& params) {
+ std::unordered_map<std::string, relay::Var> name_dict;
+ std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
+ for (auto arg : func->params) {
+ const auto& name = arg->name_hint();
+ if (name_dict.count(name)) {
+ repeat_var.insert(arg);
+ } else {
+ name_dict[name] = arg;
+ }
+ }
+
+ std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
+ for (auto& kv : params) {
+ if (name_dict.count(kv.first) == 0) {
+ continue;
+ }
+ auto arg = name_dict.at(kv.first);
+ if (repeat_var.count(arg)) {
+ LOG(FATAL) << "Multiple args in the function have name " << kv.first;
+ }
+ bind_dict[arg] = ConstantNode::make(kv.second);
+ }
+ Expr bound_expr = relay::Bind(func, bind_dict);
+ Function ret = Downcast<Function>(bound_expr);
+ CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
+ << "\n";
+ return ret;
+}
+
} // namespace backend
} // namespace relay
} // namespace tvm
#include <memory>
#include <string>
#include <tuple>
-#include <unordered_map>
-#include <unordered_set>
#include <vector>
+#include "../utils.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../op/op_common.h"
params_[name] = data_in;
}
-relay::Function VMCompiler::BindParamsByName(
- relay::Function func,
- const std::unordered_map<std::string, runtime::NDArray>& params) {
- std::unordered_map<std::string, relay::Var> name_dict;
- std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
- for (auto arg : func->params) {
- const auto &name = arg->name_hint();
- if (name_dict.count(name)) {
- repeat_var.insert(arg);
- } else {
- name_dict[name] = arg;
- }
- }
- std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
- for (auto &kv : params) {
- if (name_dict.count(kv.first) == 0) {
- continue;
- }
- auto arg = name_dict.at(kv.first);
- if (repeat_var.count(arg)) {
- LOG(FATAL) << "Multiple args in the function have name " << kv.first;
- }
- bind_dict[arg] = ConstantNode::make(kv.second);
- }
- Expr bound_expr = relay::Bind(func, bind_dict);
- Function ret = Downcast<Function>(bound_expr);
- CHECK(ret.defined())
- << "The returning type is expected to be a Relay Function."
- << "\n";
- return ret;
-}
-
void VMCompiler::Lower(IRModule mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
BaseFunc base_func = mod->Lookup("main");
CHECK(base_func->IsInstance<FunctionNode>())
<< "VM compiler expects to compile relay::Function";
- auto f = BindParamsByName(Downcast<Function>(base_func), params_);
+ auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
void Codegen();
protected:
- /*!
- * \brief Bind params to function by using name
- * \param func Relay function
- * \param params params dict
- * \return relay::Function
- */
- relay::Function BindParamsByName(
- relay::Function func,
- const std::unordered_map<std::string, runtime::NDArray>& params);
-
IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets);
void PopulateGlobalMap();