From e8c6adc6fb700c6dd181b5b3f059c135d5fee6d5 Mon Sep 17 00:00:00 2001
From: Peter Yeh
Date: Fri, 6 Sep 2019 20:41:35 -0700
Subject: [PATCH] Add .hsaco save/load for ROCm target (#3852)

fix lld
---
 python/tvm/contrib/rocm.py           |  2 +-
 src/codegen/llvm/codegen_amdgpu.cc   |  4 ++--
 tutorials/tensor_expr_get_started.py | 12 +++++++++---
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py
index 5db8a38..df4656b 100644
--- a/python/tvm/contrib/rocm.py
+++ b/python/tvm/contrib/rocm.py
@@ -46,7 +46,7 @@ def find_lld(required=True):
         major = codegen.llvm_version_major()
         lld_list += ["ld.lld-%d.0" % major]
         lld_list += ["ld.lld-%d" % major]
-    lld_list += ["lld"]
+    lld_list += ["ld.lld"]
     valid_list = [util.which(x) for x in lld_list]
     valid_list = [x for x in valid_list if x]
     if not valid_list and required:
diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc
index 396ae59..2784090 100644
--- a/src/codegen/llvm/codegen_amdgpu.cc
+++ b/src/codegen/llvm/codegen_amdgpu.cc
@@ -170,8 +170,8 @@ inline int DetectROCMComputeVersion(const std::string& target) {
       return val.operator int();
     }
   }
-  LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx803";
-  return 803;
+  LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx900";
+  return 900;
 }
 
 runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
diff --git a/tutorials/tensor_expr_get_started.py b/tutorials/tensor_expr_get_started.py
index 1b5eabc..efdd499 100644
--- a/tutorials/tensor_expr_get_started.py
+++ b/tutorials/tensor_expr_get_started.py
@@ -33,7 +33,7 @@ import numpy as np
 
 # Global declarations of environment.
 tgt_host="llvm"
-# Change it to respective GPU if gpu is enabled Ex: cuda, opencl
+# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
 tgt="cuda"
 
 ######################################################################
@@ -113,7 +113,7 @@ bx, tx = s[C].split(C.op.axis[0], factor=64)
 # compute grid. These are GPU specific constructs that allow us
 # to generate code that runs on GPU.
 #
-if tgt == "cuda" or tgt.startswith('opencl'):
+if tgt == "cuda" or tgt == "rocm" or tgt.startswith('opencl'):
     s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
     s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
 
@@ -168,7 +168,7 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 #
 # The following code fetches the device module and prints the content code.
 #
-if tgt == "cuda" or tgt.startswith('opencl'):
+if tgt == "cuda" or tgt == "rocm" or tgt.startswith('opencl'):
     dev_module = fadd.imported_modules[0]
     print("-----GPU code-----")
     print(dev_module.get_source())
@@ -212,6 +212,8 @@ temp = util.tempdir()
 fadd.save(temp.relpath("myadd.o"))
 if tgt == "cuda":
     fadd.imported_modules[0].save(temp.relpath("myadd.ptx"))
+if tgt == "rocm":
+    fadd.imported_modules[0].save(temp.relpath("myadd.hsaco"))
 if tgt.startswith('opencl'):
     fadd.imported_modules[0].save(temp.relpath("myadd.cl"))
 cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
@@ -238,6 +240,10 @@ if tgt == "cuda":
     fadd1_dev = tvm.module.load(temp.relpath("myadd.ptx"))
     fadd1.import_module(fadd1_dev)
 
+if tgt == "rocm":
+    fadd1_dev = tvm.module.load(temp.relpath("myadd.hsaco"))
+    fadd1.import_module(fadd1_dev)
+
 if tgt.startswith('opencl'):
     fadd1_dev = tvm.module.load(temp.relpath("myadd.cl"))
     fadd1.import_module(fadd1_dev)
-- 
2.7.4
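
For reference, a minimal end-to-end sketch of the save/load round trip this
patch enables. This is not part of the patch itself: it assumes a ROCm-enabled
TVM build with the 0.6-era API and an available ROCm device, and the kernel
and file names (mymul, mymul.hsaco) are illustrative.

    import numpy as np
    import tvm
    from tvm.contrib import cc, util

    # Build a trivial elementwise kernel for the rocm target.
    n = 1024
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda i: A[i] * 2.0, name='B')
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
    fmul = tvm.build(s, [A, B], "rocm", target_host="llvm", name="mymul")

    # Save the host code as an object file and, with this patch, the device
    # module as a ROCm .hsaco code object.
    temp = util.tempdir()
    fmul.save(temp.relpath("mymul.o"))
    fmul.imported_modules[0].save(temp.relpath("mymul.hsaco"))

    # Load both back and re-attach the device module to the host module.
    cc.create_shared(temp.relpath("mymul.so"), [temp.relpath("mymul.o")])
    fmul1 = tvm.module.load(temp.relpath("mymul.so"))
    fmul1.import_module(tvm.module.load(temp.relpath("mymul.hsaco")))

    # Run the reloaded module on a ROCm device and check the result.
    ctx = tvm.rocm(0)
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
    fmul1(a, b)
    np.testing.assert_allclose(b.asnumpy(), a.asnumpy() * 2.0)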