proper device query through rocm api (#4305)
authorPeter Yeh <petrex@users.noreply.github.com>
Sat, 16 Nov 2019 06:39:44 +0000 (22:39 -0800)
committermasahi <masahi129@gmail.com>
Sat, 16 Nov 2019 06:39:44 +0000 (15:39 +0900)
src/runtime/rocm/rocm_device_api.cc

index c49af89..a599dbd 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  * \file rocm_device_api.cc
  * \brief GPU specific API
  */
-#include <tvm/runtime/device_api.h>
-
 #include <dmlc/logging.h>
 #include <dmlc/thread_local.h>
 #include <hip/hip_runtime_api.h>
 #include <hsa/hsa.h>
+#include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
-#include "../../../include/tvm/runtime/device_api.h"
+
 #include "rocm_common.h"
 
 namespace tvm {
@@ -55,19 +54,57 @@ class ROCMDeviceAPI final : public DeviceAPI {
         break;
       }
       case kMaxThreadsPerBlock: {
-        value = 1024;
+        ROCM_CALL(hipDeviceGetAttribute(
+            &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id));
         break;
       }
       case kWarpSize: {
-        value = 64;
+        ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize,
+                                        ctx.device_id));
         break;
       }
-      case kMaxSharedMemoryPerBlock: return;
-      case kComputeVersion:
-      case kDeviceName: return;
-      case kMaxClockRate: return;
-      case kMultiProcessorCount: return;
-      case kMaxThreadDimensions: return;
+      case kMaxSharedMemoryPerBlock: {
+        ROCM_CALL(hipDeviceGetAttribute(
+            &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id));
+        break;
+      }
+      case kComputeVersion: {
+        std::ostringstream os;
+        ROCM_CALL(hipDeviceGetAttribute(
+            &value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id));
+        os << value << ".";
+        ROCM_CALL(hipDeviceGetAttribute(
+            &value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id));
+        os << value;
+        *rv = os.str();
+        return;
+      }
+      case kDeviceName:
+        return;
+      case kMaxClockRate: {
+        ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate,
+                                        ctx.device_id));
+        break;
+      }
+      case kMultiProcessorCount: {
+        ROCM_CALL(hipDeviceGetAttribute(
+            &value, hipDeviceAttributeMultiprocessorCount, ctx.device_id));
+        break;
+      }
+      case kMaxThreadDimensions: {
+        int dims[3];
+        ROCM_CALL(hipDeviceGetAttribute(
+            &dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id));
+        ROCM_CALL(hipDeviceGetAttribute(
+            &dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id));
+        ROCM_CALL(hipDeviceGetAttribute(
+            &dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id));
+
+        std::stringstream ss;
+        ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
+        *rv = ss.str();
+        return;
+      }
       case kGcnArch: {
         hipDeviceProp_t prop;
         ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
@@ -77,14 +114,11 @@ class ROCMDeviceAPI final : public DeviceAPI {
     }
     *rv = value;
   }
-  void* AllocDataSpace(TVMContext ctx,
-                       size_t nbytes,
-                       size_t alignment,
+  void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
                        TVMType type_hint) final {
     ROCM_CALL(hipSetDevice(ctx.device_id));
-    CHECK_EQ(256 % alignment, 0U)
-        << "ROCM space is aligned at 256 bytes";
-    void *ret;
+    CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes";
+    void* ret;
     ROCM_CALL(hipMalloc(&ret, nbytes));
     return ret;
   }
@@ -94,14 +128,9 @@ class ROCMDeviceAPI final : public DeviceAPI {
     ROCM_CALL(hipFree(ptr));
   }
 
-  void CopyDataFromTo(const void* from,
-                      size_t from_offset,
-                      void* to,
-                      size_t to_offset,
-                      size_t size,
-                      TVMContext ctx_from,
-                      TVMContext ctx_to,
-                      TVMType type_hint,
+  void CopyDataFromTo(const void* from, size_t from_offset, void* to,
+                      size_t to_offset, size_t size, TVMContext ctx_from,
+                      TVMContext ctx_to, TVMType type_hint,
                       TVMStreamHandle stream) final {
     hipStream_t hip_stream = static_cast<hipStream_t>(stream);
     from = static_cast<const char*>(from) + from_offset;
@@ -111,14 +140,15 @@ class ROCMDeviceAPI final : public DeviceAPI {
       if (ctx_from.device_id == ctx_to.device_id) {
         GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
       } else {
-        hipMemcpyPeerAsync(to, ctx_to.device_id,
-                            from, ctx_from.device_id,
-                            size, hip_stream);
+        hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size,
+                           hip_stream);
       }
-    } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
+    } else if (ctx_from.device_type == kDLROCM &&
+               ctx_to.device_type == kDLCPU) {
       ROCM_CALL(hipSetDevice(ctx_from.device_id));
       GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
-    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
+    } else if (ctx_from.device_type == kDLCPU &&
+               ctx_to.device_type == kDLROCM) {
       ROCM_CALL(hipSetDevice(ctx_to.device_id));
       GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
     } else {
@@ -132,8 +162,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
   }
 
   void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
-    ROCMThreadEntry::ThreadLocal()
-        ->stream = static_cast<hipStream_t>(stream);
+    ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
   }
 
   void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
@@ -151,11 +180,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
   }
 
  private:
-  static void GPUCopy(const void* from,
-                      void* to,
-                      size_t size,
-                      hipMemcpyKind kind,
-                      hipStream_t stream) {
+  static void GPUCopy(const void* from, void* to, size_t size,
+                      hipMemcpyKind kind, hipStream_t stream) {
     if (stream != 0) {
       ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
     } else {
@@ -166,19 +192,16 @@ class ROCMDeviceAPI final : public DeviceAPI {
 
 typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
 
-ROCMThreadEntry::ROCMThreadEntry()
-    : pool(kDLROCM, ROCMDeviceAPI::Global()) {
-}
+ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {}
 
 ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
   return ROCMThreadStore::Get();
 }
 
 TVM_REGISTER_GLOBAL("device_api.rocm")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
-    *rv = static_cast<void*>(ptr);
-  });
-
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
+      *rv = static_cast<void*>(ptr);
+    });
 }  // namespace runtime
 }  // namespace tvm