From 0235d2838ad61038d1182136f7b190bc9cb753b5 Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Thu, 14 Nov 2019 20:43:47 -0800 Subject: [PATCH] [RUNTIME] Add device query for AMD GcnArch (#4341) * add gcnArch query * kGcnArch query for cuda is a no-op --- include/tvm/runtime/device_api.h | 3 ++- src/codegen/llvm/codegen_amdgpu.cc | 2 +- src/runtime/cuda/cuda_device_api.cc | 1 + src/runtime/metal/metal_device_api.mm | 1 + src/runtime/opencl/opencl_device_api.cc | 1 + src/runtime/opengl/opengl_device_api.cc | 1 + src/runtime/rocm/rocm_device_api.cc | 14 ++++++++------ src/runtime/vulkan/vulkan.cc | 2 ++ 8 files changed, 17 insertions(+), 8 deletions(-) diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index bb362dc..4b0fcd3 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -42,7 +42,8 @@ enum DeviceAttrKind : int { kDeviceName = 5, kMaxClockRate = 6, kMultiProcessorCount = 7, - kMaxThreadDimensions = 8 + kMaxThreadDimensions = 8, + kGcnArch = 9 }; /*! \brief Number of bytes each allocation must align to */ diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index 28b2deb..19179c3 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -174,7 +174,7 @@ inline int DetectROCMComputeVersion(const std::string& target) { TVMRetValue val; api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val); if (val.operator int() == 1) { - tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val); + tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kGcnArch, &val); return val.operator int(); } } diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 87ee404..a504eda 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -105,6 +105,7 @@ class CUDADeviceAPI final : public DeviceAPI { *rv = ss.str(); return; } + case kGcnArch: return; } *rv = value; } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index e38329a..d319e50 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -63,6 +63,7 @@ void MetalWorkspace::GetAttr( case kMultiProcessorCount: return; case kMaxThreadDimensions: return; case kExist: break; + case kGcnArch: return; } } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 1e6af53..882ee83 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -114,6 +114,7 @@ void OpenCLWorkspace::GetAttr( *rv = ss.str(); break; } + case kGcnArch: return; } } diff --git a/src/runtime/opengl/opengl_device_api.cc b/src/runtime/opengl/opengl_device_api.cc index db06552..1b1487e 100644 --- a/src/runtime/opengl/opengl_device_api.cc +++ b/src/runtime/opengl/opengl_device_api.cc @@ -117,6 +117,7 @@ void OpenGLWorkspace::GetAttr( case kMaxClockRate: return; case kMultiProcessorCount: return; case kMaxThreadDimensions: return; + case kGcnArch: return; } } diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index cff72f5..c49af89 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -26,9 +26,10 @@ #include #include -#include #include #include +#include +#include "../../../include/tvm/runtime/device_api.h" #include "rocm_common.h" namespace tvm { @@ -62,16 +63,17 @@ class ROCMDeviceAPI final : public DeviceAPI { break; } case kMaxSharedMemoryPerBlock: return; - case kComputeVersion: { + case kComputeVersion: + case kDeviceName: return; + case kMaxClockRate: return; + case kMultiProcessorCount: return; + case kMaxThreadDimensions: return; + case kGcnArch: { hipDeviceProp_t prop; ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); *rv = prop.gcnArch; return; } - case kDeviceName: return; - case kMaxClockRate: return; - case kMultiProcessorCount: return; - case kMaxThreadDimensions: return; } *rv = value; } diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index daf4ae7..b14260e 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -398,6 +398,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* break; case kMaxThreadDimensions: break; + case kGcnArch: + return; } } -- 2.7.4