Add .hsaco save/load for ROCm target (#3852)
authorPeter Yeh <petrex@users.noreply.github.com>
Sat, 7 Sep 2019 03:41:35 +0000 (20:41 -0700)
committermasahi <masahi129@gmail.com>
Sat, 7 Sep 2019 03:41:35 +0000 (12:41 +0900)
fix lld

python/tvm/contrib/rocm.py
src/codegen/llvm/codegen_amdgpu.cc
tutorials/tensor_expr_get_started.py

index 5db8a38..df4656b 100644 (file)
@@ -46,7 +46,7 @@ def find_lld(required=True):
         major = codegen.llvm_version_major()
         lld_list += ["ld.lld-%d.0" % major]
         lld_list += ["ld.lld-%d" % major]
-    lld_list += ["lld"]
+    lld_list += ["ld.lld"]
     valid_list = [util.which(x) for x in lld_list]
     valid_list = [x for x in valid_list if x]
     if not valid_list and required:
index 396ae59..2784090 100644 (file)
@@ -170,8 +170,8 @@ inline int DetectROCMComputeVersion(const std::string& target) {
       return val.operator int();
     }
   }
-  LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx803";
-  return 803;
+  LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx900";
+  return 900;
 }
 
 runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
index 1b5eabc..efdd499 100644 (file)
@@ -33,7 +33,7 @@ import numpy as np
 # Global declarations of environment.
 
 tgt_host="llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl
+# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
 tgt="cuda"
 
 ######################################################################
@@ -113,7 +113,7 @@ bx, tx = s[C].split(C.op.axis[0], factor=64)
 # compute grid. These are GPU specific constructs that allow us
 # to generate code that runs on GPU.
 #
-if tgt == "cuda" or tgt.startswith('opencl'):
+if tgt == "cuda" or tgt == "rocm" or tgt.startswith('opencl'):
   s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
   s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
 
@@ -168,7 +168,7 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 #
 # The following code fetches the device module and prints the content code.
 #
-if tgt == "cuda" or tgt.startswith('opencl'):
+if tgt == "cuda" or tgt == "rocm" or tgt.startswith('opencl'):
     dev_module = fadd.imported_modules[0]
     print("-----GPU code-----")
     print(dev_module.get_source())
@@ -212,6 +212,8 @@ temp = util.tempdir()
 fadd.save(temp.relpath("myadd.o"))
 if tgt == "cuda":
     fadd.imported_modules[0].save(temp.relpath("myadd.ptx"))
+if tgt == "rocm":
+    fadd.imported_modules[0].save(temp.relpath("myadd.hsaco"))
 if tgt.startswith('opencl'):
     fadd.imported_modules[0].save(temp.relpath("myadd.cl"))
 cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
@@ -238,6 +240,10 @@ if tgt == "cuda":
     fadd1_dev = tvm.module.load(temp.relpath("myadd.ptx"))
     fadd1.import_module(fadd1_dev)
 
+if tgt == "rocm":
+    fadd1_dev = tvm.module.load(temp.relpath("myadd.hsaco"))
+    fadd1.import_module(fadd1_dev)
+
 if tgt.startswith('opencl'):
     fadd1_dev = tvm.module.load(temp.relpath("myadd.cl"))
     fadd1.import_module(fadd1_dev)