* [Target][Codegen] Make all code generator use Target class instead of target string
* Remove dep to TargetNode::str() in LLVM module
* Allow for llvm nvptx codegen
* ...
* Address comments from Cody
* Rename UpdateTargetConfig => UpdateTargetConfigKeyValueEntry
* \param target The target to be built.
* \return The result runtime::Module.
*/
-runtime::Module Build(IRModule mod, const Target& target);
+runtime::Module Build(IRModule mod, Target target);
/*!
* \brief Pack imported device library to a C file.
Array<String> keys;
/*! \brief Collection of attributes */
Map<String, ObjectRef> attrs;
-
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
+ /*! \return Export target to JSON-like configuration */
+ TVM_DLL Map<String, ObjectRef> Export() const;
void VisitAttrs(AttrVisitor* v) {
v->Visit("kind", &kind);
def __exit__(self, ptype, value, trace):
_ffi_api.ExitTargetScope(self)
+ def export(self):
+ return _ffi_api.TargetExport(self)
+
@staticmethod
def current(allow_none=True):
"""Returns the current target.
}
return fmap;
}
+
+inline void UpdateTargetConfigKeyValueEntry(const String& key, const String& value,
+ Map<String, ObjectRef>* target_config,
+ bool error_if_inconsistent) {
+ if (target_config->count(key)) {
+ const ObjectRef& obj = (*target_config)[key];
+ CHECK(obj->IsInstance<StringObj>()) << "TypeError: Expect target key \"" << key
+ << "\" to be String, but gets type: " << obj->GetTypeKey();
+ if (error_if_inconsistent) {
+ String old_value = Downcast<String>(obj);
+ CHECK_EQ(old_value, value) << "ValueError: Target key \"" << key << "\" has been set to \""
+ << old_value << "\", and cannot be reset to \"" << value << "\"";
+ }
+ }
+ target_config->Set(key, value);
+}
+
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_BUILD_COMMON_H_
namespace tvm {
namespace codegen {
-runtime::Module Build(IRModule mod, const Target& target) {
+runtime::Module Build(IRModule mod, Target target) {
if (transform::PassContext::Current()
->GetConfig<Bool>("tir.disable_assert", Bool(false))
.value()) {
}
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
- CHECK(bf != nullptr) << "target.build." << target << " is not enabled";
- return (*bf)(mod, target->str());
+ CHECK(bf != nullptr) << build_f_name << " is not enabled";
+ return (*bf)(mod, target);
}
/*! \brief Helper class to serialize module */
}
};
-inline int DetectROCMComputeVersion(const std::string& target) {
- size_t pos = target.find("=gfx");
- if (pos != std::string::npos) {
- int value;
- std::stringstream is(target.substr(pos + 4));
- if (is >> value) return value;
+inline int DetectROCMComputeVersion(const Target& target) {
+ if (const Optional<String> mcpu = target->GetAttr<String>("mcpu")) {
+ std::string gfx = mcpu.value();
+ if (gfx.length() >= 3 && gfx.substr(0, 3) == "gfx") {
+ int version;
+ std::stringstream is(gfx.substr(3));
+ if (is >> version) {
+ return version;
+ }
+ }
+ LOG(FATAL) << "ValueError: Unrecognized -mcpu value: " << mcpu;
}
TVMContext tvm_ctx;
tvm_ctx.device_type = kDLROCM;
return 305;
}
-runtime::Module BuildAMDGPU(IRModule mod, std::string target) {
+Target UpdateTarget(const Target& original_target) {
+ Map<String, ObjectRef> target_config = original_target->Export();
+ UpdateTargetConfigKeyValueEntry("mtriple", "amdgcn-amd-amdhsa-hcc", &target_config, true);
+ UpdateTargetConfigKeyValueEntry("mcpu",
+ "gfx" + std::to_string(DetectROCMComputeVersion(original_target)),
+ &target_config, false);
+ if (DetectROCMApiVersion() < 305) {
+ // before ROCm 3.5 we needed code object v2, starting
+ // with 3.5 we need v3 (this argument disables v3)
+ Array<String> mattr;
+ if (target_config.count("mattr")) {
+ mattr = Downcast<Array<String>>(target_config["mattr"]);
+ }
+ mattr.push_back("-code-object-v3");
+ target_config.Set("mattr", mattr);
+ }
+ return Target::FromConfig(target_config);
+}
+
+runtime::Module BuildAMDGPU(IRModule mod, Target original_target) {
#if TVM_LLVM_VERSION < 90
LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
// Lower versions will crash when loading the bitcode, see
// issue #4087 for a discussion
#endif
InitializeLLVM();
- CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm");
- std::ostringstream config;
- config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target);
- if (DetectROCMApiVersion() < 305) {
- // before ROCm 3.5 we needed code object v2, starting
- // with 3.5 we need v3 (this argument disables v3)
- config << " -mattr=-code-object-v3 ";
- }
- config << target.substr(4, target.length() - 4);
- std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
+ Target target = UpdateTarget(original_target);
+ std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
// have a shorter lifetime than the ctx.
#include "codegen_blob.h"
#include <tvm/runtime/module.h>
+#include <tvm/target/target.h>
#include <cstring>
std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(
const std::string& data, bool system_lib, const std::string& target_triple) {
InitializeLLVM();
- std::string full_target_triple = std::string("-mtriple ") + target_triple;
- auto tm = GetLLVMTargetMachine(full_target_triple);
+ Target target = Target::Create("llvm -mtriple " + target_triple);
+ auto tm = GetLLVMTargetMachine(target);
auto triple = tm->getTargetTriple();
auto ctx = std::make_shared<llvm::LLVMContext>();
std::string module_name = "devc";
// Store full target string in metadata, because flags such as -mfloat-abi must be preserved for
// ModulePackImportsToLLVM.
module->addModuleFlag(llvm::Module::ModFlagBehavior::Override, "tvm_target",
- llvm::MDString::get(*ctx, full_target_triple));
+ llvm::MDString::get(*ctx, LLVMTargetToString(target)));
module->setDataLayout(tm->createDataLayout());
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
auto* tvm_dev_mblob = new llvm::GlobalVariable(
} // namespace
-runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
- if (target_str.empty()) {
- LOG(FATAL) << "Unknown or invalid target.";
- }
-
+runtime::Module BuildHexagon(IRModule mod, Target target) {
// Make sure all targets are registered. InitializeLLVM can be called
// multiple times, after the first call all subsequent calls are no-ops.
InitializeLLVM();
}
return vec;
};
- auto starts_with = [](const std::string& s, const std::string& p) {
- return !s.compare(0, p.size(), p);
- };
-
- std::vector<std::string> flags = split(target_str);
- std::string llvm_target_str, llvm_options_str = "llvm";
-
- for (const auto& s : flags) {
- if (starts_with(s, "-mattr=") || starts_with(s, "-mtriple=") || starts_with(s, "-mcpu=")) {
- llvm_target_str += " " + s;
- } else if (starts_with(s, "-llvm-options=")) {
- llvm_options_str += "," + s.substr(14 /*length of -llvm-options=*/);
- }
+ std::string llvm_options_str;
+ if (const Optional<String> llvm_options = target->GetAttr<String>("llvm-options")) {
+ llvm_options_str = "llvm," + llvm_options.value();
+ } else {
+ llvm_options_str = "llvm";
}
-
// Postprocess the LLVM options string: replace '@' with '=', and ',' with ' '.
for (int i = 0, e = llvm_options_str.size(); i != e; ++i) {
switch (llvm_options_str[i]) {
static bool CallOnce = (ProcessLLVMOptions(llvm_options_vec), true);
(void)CallOnce;
- std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target_str);
+ std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
std::unique_ptr<CodeGenHexagon> cg(new CodeGenHexagon());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false);
export_abi);
}
-TVM_REGISTER_GLOBAL("target.build.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildHexagon(args[0], args[1]);
-});
+TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon);
} // namespace codegen
} // namespace tvm
}
}
-runtime::Module BuildNVPTX(IRModule mod, std::string target) {
+Target UpdateTarget(const Target& original_target, int compute_ver) {
+ Map<String, ObjectRef> target_config = original_target->Export();
+ UpdateTargetConfigKeyValueEntry("mtriple", "nvptx64-nvidia-cuda", &target_config, true);
+ UpdateTargetConfigKeyValueEntry("mcpu", "sm_" + std::to_string(compute_ver), &target_config,
+ false);
+ return Target::FromConfig(target_config);
+}
+
+runtime::Module BuildNVPTX(IRModule mod, Target original_target) {
InitializeLLVM();
- CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx");
int compute_ver = DetectCUDAComputeVersion();
- std::ostringstream config;
- config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver
- << target.substr(5, target.length() - 5);
- std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
+ Target target = UpdateTarget(original_target, compute_ver);
+ std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
// have a shorter lifetime than the ctx.
#include "llvm_common.h"
#include <dmlc/logging.h>
+#include <tvm/target/target.h>
#include <atomic>
#include <memory>
}
}
-void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu,
+void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu,
std::string* mattr, llvm::TargetOptions* options) {
- // setup target triple
- size_t start = 0;
- if (target_str.length() >= 4 && target_str.substr(0, 4) == "llvm") {
- start = 4;
- }
// simple parser
triple->resize(0);
mcpu->resize(0);
mattr->resize(0);
-
bool soft_float_abi = false;
- std::string key, value;
- std::istringstream is(target_str.substr(start, target_str.length() - start));
- while (is >> key) {
- if (key == "-system-lib" || key == "-system-lib=0" || key == "-system-lib=1") {
- continue;
- }
- size_t pos = key.find('=');
- if (pos != std::string::npos) {
- CHECK_GE(key.length(), pos + 1) << "invalid argument " << key;
- value = key.substr(pos + 1, key.length() - 1);
- key = key.substr(0, pos);
- } else {
- CHECK(is >> value) << "Unspecified value for option " << key;
+ if (const Optional<String>& v = target->GetAttr<String>("mtriple")) {
+ *triple = v.value();
+ }
+ if (const Optional<String>& v = target->GetAttr<String>("mcpu")) {
+ *mcpu = v.value();
+ }
+ if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("mattr")) {
+ std::ostringstream os;
+ bool is_first = true;
+ for (const String& s : v.value()) {
+ if (!is_first) {
+ os << ',';
+ }
+ is_first = false;
+ os << s;
}
- if (key == "-mtriple") {
- *triple = value;
- } else if (key == "-mcpu") {
- *mcpu = value;
- } else if (key == "-mattr") {
- *mattr = value;
- } else if (key == "-mfloat-abi") {
- if (value == "hard") {
+ *mattr = os.str();
+ }
+ if (const Optional<String>& v = target->GetAttr<String>("mfloat-abi")) {
+ String value = v.value();
+ if (value == "hard") {
#if TVM_LLVM_VERSION < 60
- LOG(FATAL) << "-mfloat-abi hard is only supported for LLVM > 6.0";
+ LOG(FATAL) << "-mfloat-abi hard is only supported for LLVM > 6.0";
#endif
- soft_float_abi = false;
- } else if (value == "soft") {
- soft_float_abi = true;
- } else {
- LOG(FATAL) << "invalid -mfloat-abi option " << value;
- }
+ soft_float_abi = false;
+ } else if (value == "soft") {
+ soft_float_abi = true;
+ } else {
+ LOG(FATAL) << "invalid -mfloat-abi option " << value;
}
}
-
if (triple->length() == 0 || *triple == "default") {
*triple = llvm::sys::getDefaultTargetTriple();
}
}
}
-std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const std::string& target_str,
- bool allow_null) {
+std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const Target& target, bool allow_null) {
std::string target_triple, mcpu, mattr;
llvm::TargetOptions opt;
- ParseLLVMTargetOptions(target_str, &target_triple, &mcpu, &mattr, &opt);
+ ParseLLVMTargetOptions(target, &target_triple, &mcpu, &mattr, &opt);
if (target_triple.length() == 0 || target_triple == "default") {
target_triple = llvm::sys::getDefaultTargetTriple();
}
std::string err;
- const llvm::Target* target = llvm::TargetRegistry::lookupTarget(target_triple, err);
- if (target == nullptr) {
+ const llvm::Target* llvm_target = llvm::TargetRegistry::lookupTarget(target_triple, err);
+ if (llvm_target == nullptr) {
CHECK(allow_null) << err << " target_triple=" << target_triple;
return nullptr;
}
llvm::TargetMachine* tm =
- target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_);
+ llvm_target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_);
return std::unique_ptr<llvm::TargetMachine>(tm);
}
+std::string LLVMTargetToString(const Target& target) {
+ std::ostringstream os;
+ os << "llvm";
+ if (Optional<String> mtriple = target->GetAttr<String>("mtriple")) {
+ os << " -mtriple=" << mtriple.value();
+ }
+ if (Optional<String> mcpu = target->GetAttr<String>("mcpu")) {
+ os << " -mcpu=" << mcpu.value();
+ }
+ if (Optional<Array<String>> mattr = target->GetAttr<Array<String>>("mattr")) {
+ bool is_first = true;
+ os << " -mattr=";
+ for (const String& attr : mattr.value()) {
+ if (!is_first) {
+ os << ",";
+ }
+ is_first = false;
+ os << attr;
+ }
+ }
+ if (Optional<String> mfloat_abo = target->GetAttr<String>("mfloat-abi")) {
+ os << " -mfloat-abi=" << mfloat_abo.value();
+ }
+ return os.str();
+}
+
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
#include <utility>
namespace tvm {
+
+// The TVM target
+class Target;
+
namespace codegen {
/*!
/*!
* \brief Parse target options
- * \param target_str Target string, in format "llvm -mtriple=xxx -mcpu=xxx"
+ * \param target The TVM target
* \param triple Target triple
* \param mcpu cpu info
* \param options the options
* \param mattr The attributes
*/
-void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu,
+void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu,
std::string* mattr, llvm::TargetOptions* options);
/*!
- * \brief Get target machine from target_str string.
- * \param target_str Target string, in format "llvm -mtriple=xxx -mcpu=xxx"
+ * \brief Get target machine from TVM target.
+ * \param target The TVM target
* \param allow_null Whether allow null to be returned.
* \return target machine
*/
-std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const std::string& target_str,
+std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const Target& target,
bool allow_null = false);
+/*!
+ * \brief Convert the TVM's LLVM target to string by extracting only relevant fields
+ * \param target The TVM target to be extracted
+ * \return The raw string format for the TVM LLVM target
+ */
+std::string LLVMTargetToString(const Target& target);
+
} // namespace codegen
} // namespace tvm
return "";
}
- void Init(const IRModule& mod, std::string target_str) {
+ void Init(const IRModule& mod, const Target& target) {
InitializeLLVM();
- tm_ = GetLLVMTargetMachine(target_str);
- auto target = Target::Create(target_str);
+ tm_ = GetLLVMTargetMachine(target);
bool system_lib = target->GetAttr<Bool>("system-lib").value_or(Bool(false));
bool target_c_runtime = (target->GetAttr<String>("runtime").value_or("") == kTvmRuntimeCrt);
ctx_ = std::make_shared<llvm::LLVMContext>();
module_ = cg->Finish();
module_->addModuleFlag(llvm::Module::Warning, "tvm_target",
- llvm::MDString::get(*ctx_, target_str));
+ llvm::MDString::get(*ctx_, LLVMTargetToString(target)));
module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
llvm::DEBUG_METADATA_VERSION);
LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors))
<< "LLVM module verification failed with the following errors: \n"
<< verify_errors.str();
- target_ = target_str;
+ target_ = target;
mptr_ = module_.get();
}
std::string msg = std::string(err.getMessage());
LOG(FATAL) << "Fail to load module: " << msg;
}
- std::string target_;
- llvm::Metadata* mtarget = module_->getModuleFlag("tvm_target");
- if (mtarget != nullptr) {
- llvm::MDString* pstr = llvm::dyn_cast<llvm::MDString>(mtarget);
+ std::string target_metadata;
+ llvm::Metadata* tvm_target = module_->getModuleFlag("tvm_target");
+ if (tvm_target != nullptr) {
+ llvm::MDString* pstr = llvm::dyn_cast<llvm::MDString>(tvm_target);
CHECK(pstr != nullptr);
- target_ = pstr->getString().str();
+ target_metadata = pstr->getString().str();
+ if (!(target_metadata.length() >= 4 && target_metadata.substr(0, 4) == "llvm")) {
+ target_metadata = "llvm " + target_metadata;
+ }
} else {
std::ostringstream os;
os << "llvm -mtriple " << module_->getTargetTriple();
- target_ = os.str();
+ target_metadata = os.str();
}
mptr_ = module_.get();
- tm_ = GetLLVMTargetMachine(target_);
+ tm_ = GetLLVMTargetMachine(Target::Create(target_metadata));
}
void LoadIR(const std::string& file_name) {
if (ee_) {
return;
}
+ if (!target_.defined()) {
+ target_ = Target::Create("llvm");
+ }
llvm::EngineBuilder builder(std::move(module_));
std::string triple, mcpu, mattr;
llvm::TargetOptions opt;
}
builder.setTargetOptions(opt);
auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget());
- std::unique_ptr<llvm::TargetMachine> tm_sys = GetLLVMTargetMachine("llvm");
+ std::unique_ptr<llvm::TargetMachine> tm_sys = GetLLVMTargetMachine(Target::Create("llvm"));
if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) {
LOG(FATAL) << "Cannot run module, architecture mismatch "
<< " module=" << tm->getTargetTriple().str()
}
// The target configuration string
- std::string target_;
+ Target target_;
// JIT lock
std::mutex mutex_;
// execution engine
std::shared_ptr<llvm::LLVMContext> ctx_;
};
-unsigned LookupLLVMIntrinsic(const std::string& name) {
- return llvm::Function::lookupIntrinsicID(name);
-}
-
-TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed([](IRModule mod, std::string target) {
- auto n = make_object<LLVMModuleNode>();
- n->Init(mod, target);
- return runtime::Module(n);
+TVM_REGISTER_GLOBAL("target.build.llvm")
+ .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
+ auto n = make_object<LLVMModuleNode>();
+ n->Init(mod, target);
+ return runtime::Module(n);
+ });
+
+TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
+ .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module {
+ Target target = Target::Create(target_str);
+ auto n = make_object<LLVMModuleNode>();
+ // Generate a LLVM module from an input target string
+ InitializeLLVM();
+ auto tm = GetLLVMTargetMachine(target);
+ auto ctx = std::make_shared<llvm::LLVMContext>();
+ std::unique_ptr<llvm::Module> module(new llvm::Module(module_name, *ctx));
+ // Use a default data layout and target triple
+ auto triple = tm->getTargetTriple();
+ module->setTargetTriple(triple.str());
+ module->setDataLayout(tm->createDataLayout());
+ n->Init(std::move(module), ctx);
+ return runtime::Module(n);
+ });
+
+TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id")
+ .set_body_typed([](std::string name) -> int64_t {
+ return static_cast<int64_t>(llvm::Function::lookupIntrinsicID(name));
+ });
+
+TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
+ return TVM_LLVM_VERSION / 10;
});
-TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
- auto n = make_object<LLVMModuleNode>();
- auto target = args[0].operator std::string();
- auto module_name = args[1].operator std::string();
-
- // Generate a LLVM module from an input target string
- InitializeLLVM();
- auto tm = GetLLVMTargetMachine(target);
- auto ctx = std::make_shared<llvm::LLVMContext>();
- std::unique_ptr<llvm::Module> module(new llvm::Module(module_name, *ctx));
-
- // Use a default data layout and target triple
- auto triple = tm->getTargetTriple();
- module->setTargetTriple(triple.str());
- module->setDataLayout(tm->createDataLayout());
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
+ .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module {
+ auto n = make_object<LLVMModuleNode>();
+ n->LoadIR(filename);
+ return runtime::Module(n);
+ });
+
+TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled")
+ .set_body_typed([](std::string target_str) -> bool {
+ InitializeLLVM();
+ Target target = Target::Create(target_str);
+ return (GetLLVMTargetMachine(target, true) != nullptr);
+ });
+
+TVM_REGISTER_GLOBAL("codegen.codegen_blob")
+ .set_body_typed([](std::string data, bool system_lib,
+ std::string target_triple) -> runtime::Module {
+ auto n = make_object<LLVMModuleNode>();
+ auto p = CodeGenBlob(data, system_lib, target_triple);
+ n->Init(std::move(p.first), p.second);
+ return runtime::Module(n);
+ });
- n->Init(std::move(module), ctx);
-
- *rv = runtime::Module(n);
-});
-
-TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id").set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
-});
-
-TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body([](TVMArgs args, TVMRetValue* rv) {
- int major = TVM_LLVM_VERSION / 10;
- *rv = major;
-});
-
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll").set_body([](TVMArgs args, TVMRetValue* rv) {
- auto n = make_object<LLVMModuleNode>();
- n->LoadIR(args[0]);
- *rv = runtime::Module(n);
-});
-
-TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled").set_body([](TVMArgs args, TVMRetValue* rv) {
- InitializeLLVM();
- *rv = (GetLLVMTargetMachine(args[0], true) != nullptr);
-});
-
-TVM_REGISTER_GLOBAL("codegen.codegen_blob").set_body([](TVMArgs args, TVMRetValue* rv) {
- auto n = make_object<LLVMModuleNode>();
- auto p = CodeGenBlob(args[0].operator std::string(), args[1].operator bool(),
- args[2].operator std::string());
- n->Init(std::move(p.first), p.second);
- *rv = runtime::Module(n);
-});
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
return ptx;
}
-runtime::Module BuildCUDA(IRModule mod, std::string target) {
+runtime::Module BuildCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
namespace tvm {
namespace codegen {
-runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation) {
+runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) {
// Get code.
using tvm::runtime::Registry;
bool output_ssa = false;
std::string cmd = "aoc aocl.cl";
// AOCL supports fp64.
cmd += " -Dcl_khr_fp64";
- Target target = Target::Create(target_str);
Optional<String> device = target->GetAttr<String>("device");
if (device.defined()) {
cmd += " -board=" + device.value();
return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(mod), code);
}
-TVM_REGISTER_GLOBAL("target.build.aocl").set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildAOCL(args[0], args[1], false);
-});
+TVM_REGISTER_GLOBAL("target.build.aocl")
+ .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
+ return BuildAOCL(mod, target, false);
+ });
-TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu").set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildAOCL(args[0], args[1], true);
-});
+TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu")
+ .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
+ return BuildAOCL(mod, target, true);
+ });
} // namespace codegen
} // namespace tvm
<< "}\n";
}
-runtime::Module BuildCHost(IRModule mod, const std::string& target_str) {
+runtime::Module BuildCHost(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
bool emit_asserts = false;
CodeGenCHost cg;
- auto target = Target::Create(target_str);
cg.Init(output_ssa, emit_asserts);
for (auto kv : mod->functions) {
return CSourceModuleCreate(code, "c");
}
-TVM_REGISTER_GLOBAL("target.build.c").set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildCHost(args[0], args[1]);
-});
+TVM_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost);
} // namespace codegen
} // namespace tvm
}
}
-runtime::Module BuildMetal(IRModule mod) {
+runtime::Module BuildMetal(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenMetal cg;
return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source);
}
-TVM_REGISTER_GLOBAL("target.build.metal").set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildMetal(args[0]);
-});
+TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);
} // namespace codegen
} // namespace tvm
}
}
-runtime::Module BuildOpenCL(IRModule mod, std::string target) {
+runtime::Module BuildOpenCL(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
PrintBinaryExpr(op, opstr, os, this);
}
-runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
+runtime::Module BuildSDAccel(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenVivadoHLS cg;
std::string xclbin;
if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
- Target target = Target::Create(target_str);
String device = target->GetAttr<String>("device", "").value();
xclbin = (*f)(kernel_info, device).operator std::string();
} else {
spv_context ctx_;
};
-runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restriction) {
+runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) {
using tvm::runtime::Registry;
using tvm::runtime::VulkanShader;
return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), code_data.str());
}
-TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, std::string target) {
+TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) {
return BuildSPIRV(mod, target, false);
});
-TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, std::string target) {
+TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) {
return BuildSPIRV(mod, target, true);
});
this->Push(op->body);
}
-runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) {
+runtime::Module BuildStackVM(IRModule mod, Target target) {
std::unordered_map<std::string, StackVM> fmap;
std::string entry_func;
return result;
}
+Map<String, ObjectRef> TargetNode::Export() const {
+ Map<String, ObjectRef> result = {
+ {"kind", this->kind->name},
+ {"tag", this->tag},
+ {"keys", this->keys},
+ };
+ for (const auto& kv : attrs) {
+ result.Set(kv.first, kv.second);
+ }
+ return result;
+}
+
const std::string& TargetNode::str() const {
if (str_repr_.empty()) {
std::ostringstream os;
TVM_REGISTER_GLOBAL("target.TargetFromConfig").set_body_typed(Target::FromConfig);
+TVM_REGISTER_GLOBAL("target.TargetExport")
+ .set_body_typed([](Target target) -> Map<String, ObjectRef> { return target->Export(); });
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const TargetNode*>(node.get());
- p->stream << op->str();
+ const auto* target = node.as<TargetNode>();
+ CHECK(target);
+ p->stream << target->str();
});
namespace target {
.add_attr_option<Integer>("max_num_threads", Integer(1024))
.add_attr_option<Integer>("thread_warp_size", Integer(32))
.add_attr_option<String>("mcpu")
+ .add_attr_option<String>("mtriple")
.set_default_keys({"cuda", "gpu"})
.set_device_type(kDLGPU);
auto module = build(lowered, target, Target());
auto mali_target = Target::Create("opencl -model=Mali-T860MP4@800Mhz -device=mali");
- CHECK_EQ(
- mali_target->str(),
- "opencl -keys=mali,opencl,gpu -device=mali -max_num_threads=256 -model=Mali-T860MP4@800Mhz");
+ CHECK_EQ(mali_target->kind->name, "opencl");
+ CHECK_EQ(mali_target->keys.size(), 3);
+ CHECK_EQ(mali_target->keys[0], "mali");
+ CHECK_EQ(mali_target->keys[1], "opencl");
+ CHECK_EQ(mali_target->keys[2], "gpu");
+ CHECK_EQ(mali_target->GetAttr<String>("device").value(), "mali");
+ CHECK_EQ(mali_target->GetAttr<String>("model").value(), "Mali-T860MP4@800Mhz");
+ CHECK_EQ(mali_target->GetAttr<Integer>("max_num_threads").value(), 256);
}
TEST(BuildModule, Heterogeneous) {