const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+ void SaveToFile(const std::string& file_name,
+ const std::string& format) final {
+ std::string fmt = GetFileFormat(file_name, format);
+ std::string meta_file = GetMetaFilePath(file_name);
+ // note: llvm and asm formats are not laodable, so we don't save them
+ CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
+ SaveMetaDataToFile(meta_file, fmap_);
+ SaveBinaryToFile(file_name, data_);
+ }
+
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
return Module(n);
}
+Module ROCMModuleLoadFile(const std::string& file_name,
+ const std::string& format) {
+ std::string data;
+ std::unordered_map<std::string, FunctionInfo> fmap;
+ std::string fmt = GetFileFormat(file_name, format);
+ std::string meta_file = GetMetaFilePath(file_name);
+ LoadBinaryFromFile(file_name, &data);
+ LoadMetaDataFromFile(meta_file, &fmap);
+ return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string());
+}
+
Module ROCMModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::string data;
TVM_REGISTER_GLOBAL("module.loadbinary_hip")
.set_body_typed(ROCMModuleLoadBinary);
+
+
+TVM_REGISTER_GLOBAL("module.loadfile_hsaco")
+.set_body_typed(ROCMModuleLoadFile);
+
+TVM_REGISTER_GLOBAL("module.loadfile_hip")
+.set_body_typed(ROCMModuleLoadFile);
} // namespace runtime
} // namespace tvm
return
if not tvm.module.enabled(host):
return
- fmt = "ptx" if device == "cuda" else device
+ if device == "cuda":
+ fmt = "ptx"
+ elif device == "rocm":
+ fmt = "hsaco"
+ else:
+ fmt = device
mhost = tvm.codegen.build_module(fsplits[0], host)
mdev = tvm.codegen.build_module(fsplits[1:], device)
temp = util.tempdir()
check_module_save("cuda", host="stackvm")
check_target("nvptx", host="llvm")
check_target("vulkan", host="llvm")
- check_target("rocm", host="llvm")
check_module_save("vulkan", host="stackvm")
+ check_target("rocm", host="llvm")
+ check_module_save("rocm", host="llvm")
if __name__ == "__main__":