ROCm: Add SaveToFile and LoadFile (#3665)
authorThomas Viehmann <tv.code@beamnet.de>
Tue, 30 Jul 2019 14:54:16 +0000 (16:54 +0200)
committermasahi <masahi129@gmail.com>
Tue, 30 Jul 2019 14:54:16 +0000 (23:54 +0900)
...and add rocm module_save to the tests.

src/runtime/rocm/rocm_module.cc
tests/python/unittest/test_codegen_device.py

index 0336bae..96d1948 100644 (file)
@@ -71,6 +71,16 @@ class ROCMModuleNode : public runtime::ModuleNode {
       const std::shared_ptr<ModuleNode>& sptr_to_self) final;
 
 
+  void SaveToFile(const std::string& file_name,
+                  const std::string& format) final {
+    std::string fmt = GetFileFormat(file_name, format);
+    std::string meta_file = GetMetaFilePath(file_name);
+    // note: llvm and asm formats are not laodable, so we don't save them
+    CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
+    SaveMetaDataToFile(meta_file, fmap_);
+    SaveBinaryToFile(file_name, data_);
+  }
+
   void SaveToBinary(dmlc::Stream* stream) final {
     stream->Write(fmt_);
     stream->Write(fmap_);
@@ -230,6 +240,17 @@ Module ROCMModuleCreate(
   return Module(n);
 }
 
+Module ROCMModuleLoadFile(const std::string& file_name,
+                          const std::string& format) {
+  std::string data;
+  std::unordered_map<std::string, FunctionInfo> fmap;
+  std::string fmt = GetFileFormat(file_name, format);
+  std::string meta_file = GetMetaFilePath(file_name);
+  LoadBinaryFromFile(file_name, &data);
+  LoadMetaDataFromFile(meta_file, &fmap);
+  return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string());
+}
+
 Module ROCMModuleLoadBinary(void* strm) {
   dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
   std::string data;
@@ -248,5 +269,12 @@ TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
 
 TVM_REGISTER_GLOBAL("module.loadbinary_hip")
 .set_body_typed(ROCMModuleLoadBinary);
+
+
+TVM_REGISTER_GLOBAL("module.loadfile_hsaco")
+.set_body_typed(ROCMModuleLoadFile);
+
+TVM_REGISTER_GLOBAL("module.loadfile_hip")
+.set_body_typed(ROCMModuleLoadFile);
 }  // namespace runtime
 }  // namespace tvm
index 9532975..6cb424c 100644 (file)
@@ -76,7 +76,12 @@ def test_add_pipeline():
             return
         if not tvm.module.enabled(host):
             return
-        fmt = "ptx" if device == "cuda" else device
+        if device == "cuda":
+            fmt = "ptx"
+        elif device == "rocm":
+            fmt = "hsaco"
+        else:
+            fmt = device
         mhost = tvm.codegen.build_module(fsplits[0], host)
         mdev = tvm.codegen.build_module(fsplits[1:], device)
         temp = util.tempdir()
@@ -99,8 +104,9 @@ def test_add_pipeline():
     check_module_save("cuda", host="stackvm")
     check_target("nvptx", host="llvm")
     check_target("vulkan", host="llvm")
-    check_target("rocm", host="llvm")
     check_module_save("vulkan", host="stackvm")
+    check_target("rocm", host="llvm")
+    check_module_save("rocm", host="llvm")
 
 
 if __name__ == "__main__":