[RUNTIME] Add device query for AMD GcnArch (#4341)
authorPeter Yeh <petrex@users.noreply.github.com>
Fri, 15 Nov 2019 04:43:47 +0000 (20:43 -0800)
committermasahi <masahi129@gmail.com>
Fri, 15 Nov 2019 04:43:47 +0000 (13:43 +0900)
* add gcnArch query

* kGcnArch query for cuda is a no-op

include/tvm/runtime/device_api.h
src/codegen/llvm/codegen_amdgpu.cc
src/runtime/cuda/cuda_device_api.cc
src/runtime/metal/metal_device_api.mm
src/runtime/opencl/opencl_device_api.cc
src/runtime/opengl/opengl_device_api.cc
src/runtime/rocm/rocm_device_api.cc
src/runtime/vulkan/vulkan.cc

index bb362dc..4b0fcd3 100644 (file)
@@ -42,7 +42,8 @@ enum DeviceAttrKind : int {
   kDeviceName = 5,
   kMaxClockRate = 6,
   kMultiProcessorCount = 7,
-  kMaxThreadDimensions = 8
+  kMaxThreadDimensions = 8,
+  kGcnArch = 9
 };
 
 /*! \brief Number of bytes each allocation must align to */
index 28b2deb..19179c3 100644 (file)
@@ -174,7 +174,7 @@ inline int DetectROCMComputeVersion(const std::string& target) {
     TVMRetValue val;
     api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
     if (val.operator int() == 1) {
-      tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
+      tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kGcnArch, &val);
       return val.operator int();
     }
   }
index 87ee404..a504eda 100644 (file)
@@ -105,6 +105,7 @@ class CUDADeviceAPI final : public DeviceAPI {
         *rv = ss.str();
         return;
       }
+      case kGcnArch: return;
     }
     *rv = value;
   }
index e38329a..d319e50 100644 (file)
@@ -63,6 +63,7 @@ void MetalWorkspace::GetAttr(
     case kMultiProcessorCount: return;
     case kMaxThreadDimensions: return;
     case kExist: break;
+    case kGcnArch: return; 
   }
 }
 
index 1e6af53..882ee83 100644 (file)
@@ -114,6 +114,7 @@ void OpenCLWorkspace::GetAttr(
       *rv = ss.str();
       break;
     }
+    case kGcnArch: return;
   }
 }
 
index db06552..1b1487e 100644 (file)
@@ -117,6 +117,7 @@ void OpenGLWorkspace::GetAttr(
     case kMaxClockRate: return;
     case kMultiProcessorCount: return;
     case kMaxThreadDimensions: return;
+    case kGcnArch: return;
   }
 }
 
index cff72f5..c49af89 100644 (file)
 
 #include <dmlc/logging.h>
 #include <dmlc/thread_local.h>
-#include <tvm/runtime/registry.h>
 #include <hip/hip_runtime_api.h>
 #include <hsa/hsa.h>
+#include <tvm/runtime/registry.h>
+#include "../../../include/tvm/runtime/device_api.h"
 #include "rocm_common.h"
 
 namespace tvm {
@@ -62,16 +63,17 @@ class ROCMDeviceAPI final : public DeviceAPI {
         break;
       }
       case kMaxSharedMemoryPerBlock: return;
-      case kComputeVersion: {
+      case kComputeVersion:
+      case kDeviceName: return;
+      case kMaxClockRate: return;
+      case kMultiProcessorCount: return;
+      case kMaxThreadDimensions: return;
+      case kGcnArch: {
         hipDeviceProp_t prop;
         ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
         *rv = prop.gcnArch;
         return;
       }
-      case kDeviceName: return;
-      case kMaxClockRate: return;
-      case kMultiProcessorCount: return;
-      case kMaxThreadDimensions: return;
     }
     *rv = value;
   }
index daf4ae7..b14260e 100644 (file)
@@ -398,6 +398,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
       break;
     case kMaxThreadDimensions:
       break;
+    case kGcnArch:
+      return;
   }
 }