*/
/*!
- * Copyright (c) 2019 by Contributors
* \file relay/backend/build_module.cc
* \brief Code generation for TVM's graph runtime.
*/
-
#include <tvm/build_module.h>
+#include <tvm/runtime/device_api.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
using TargetsMap = Map<tvm::Integer, tvm::Target>;
/*!
- * \brief Context index to Target
- */
-struct ContextTargetMap {
- static const std::unordered_map<int, tvm::Target> mask2str;
- static tvm::Target Mask2Str(int mask) {
- CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
- return mask2str.at(mask);
- }
-};
-
-const std::unordered_map<int, tvm::Target> ContextTargetMap::mask2str = {
- {1, tvm::Target::create("llvm")},
- {2, tvm::Target::create("cuda")},
- {4, tvm::Target::create("opencl")},
- {5, tvm::Target::create("aocl")},
- {6, tvm::Target::create("sdaccel")},
- {7, tvm::Target::create("vulkan")},
- {8, tvm::Target::create("metal")},
- {9, tvm::Target::create("vpi")},
- {10, tvm::Target::create("rocm")},
- {11, tvm::Target::create("opengl")},
- {12, tvm::Target::create("ext_dev")}
-};
-
-/*!
* \brief A data structure to map the names of specific optimizations to
* numeric optimization levels
*
*
* \return Array<StringImm> names of params
*/
- Array<HalideIR::Expr> ListParamNames() {
- Array<HalideIR::Expr> ret;
+ Array<tvm::Expr> ListParamNames() {
+ Array<tvm::Expr> ret;
for (const auto& kv : params_) {
ret.push_back(ir::StringImm::make(kv.first));
}
if (cfg.pass_enabled("AlterOpLayout")) {
if (targets.size() == 1) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
- auto enter_pf = GetPackedFunc("_EnterTargetScope");
- auto exit_pf = GetPackedFunc("_ExitTargetScope");
for (const auto& kv : targets) {
- (*enter_pf)(kv.second);
+ TargetContext tctx(kv.second);
func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
- (*exit_pf)();
}
} else {
LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
}
return func;
}
+
+ /*!
+ * \brief Create a default type.
+ * \param device_type The device type index.
+ * \return the default target for the device.
+ */
+ Target CreateDefaultTarget(int device_type) {
+ std::string name = runtime::DeviceName(device_type);
+ if (name == "cpu") return Target::create("llvm");
+ if (name == "gpu") return Target::create("cuda");
+ return Target::create(name);
+ }
/*!
* \brief Update the target and fallback device required for heterogeneous
* compilation. CPU is used as the fallback device if it wasn't provided.
if (tmp_map.count(cfg.fallback_device) == 0) {
device_target.Set(
cfg.fallback_device,
- ContextTargetMap::Mask2Str(cfg.fallback_device));
+ CreateDefaultTarget(cfg.fallback_device));
}
return device_target;
}
* \param targets_map_ptr
* \return Function
*/
- Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg,
+ Function RunDeviceAnnotationPass(Function func,
+ const RelayBuildConfig& cfg,
TargetsMap* targets_map_ptr) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
"relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
if (annotation_map.size() == 0) {
targets_map_ptr->Set(
- 0, ContextTargetMap::Mask2Str(cfg.fallback_device));
+ 0, CreateDefaultTarget(cfg.fallback_device));
} else {
int64_t dev_type = -1;
for (auto kv : annotation_map) {
<< "found. Please check the "
<< "RewriteAnnotation pass.";
}
- targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type));
+ targets_map_ptr->Set(0, CreateDefaultTarget(dev_type));
}
}
return func;
return runtime::Module(exec);
}
-TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RelayBuildCreate();
});