Two small fixes to AMDCPU codegen for LLVM 10+ and ROCm 3.5+ (#5920)
authorThomas Viehmann <tv.code@beamnet.de>
Thu, 25 Jun 2020 16:59:12 +0000 (18:59 +0200)
committerGitHub <noreply@github.com>
Thu, 25 Jun 2020 16:59:12 +0000 (09:59 -0700)
- For LLVM 10+ we need to avoid calling Align with 0, or else
  we get a crash.
- For ROCm 3.5+ we need to use code object 3 (the default in LLVM 9+)
  but for ROCm < 3.5 we want the code object 2.
- As we want to separate codegen from the API, we need to add
  a device api query for the version.
  But every one else wants now one, too. (But I only filled it
  in for CUDA for now.)
- I'm throwing in an addition of kMaxRegistersPerBlock for ROCm.
  This was introduced for CUDA in #5898.

include/tvm/runtime/device_api.h
src/runtime/cuda/cuda_device_api.cc
src/runtime/metal/metal_device_api.mm
src/runtime/opencl/opencl_device_api.cc
src/runtime/rocm/rocm_common.h
src/runtime/rocm/rocm_device_api.cc
src/runtime/vulkan/vulkan.cc
src/target/llvm/codegen_amdgpu.cc

index 3cf5566..c6a2ce3 100644 (file)
@@ -45,7 +45,8 @@ enum DeviceAttrKind : int {
   kMultiProcessorCount = 7,
   kMaxThreadDimensions = 8,
   kMaxRegistersPerBlock = 9,
-  kGcnArch = 10
+  kGcnArch = 10,
+  kApiVersion = 11
 };
 
 /*! \brief Number of bytes each allocation must align to */
index ccd8e91..14444c9 100644 (file)
@@ -98,6 +98,10 @@ class CUDADeviceAPI final : public DeviceAPI {
       }
       case kGcnArch:
         return;
+      case kApiVersion: {
+        *rv = CUDA_VERSION;
+        return;
+      }
     }
     *rv = value;
   }
index a64f35c..f2a2930 100644 (file)
@@ -69,6 +69,8 @@ void MetalWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* r
       return;
     case kGcnArch:
       return;
+    case kApiVersion:
+      return;
   }
 }
 
index 72d03fb..5753c1d 100644 (file)
@@ -111,6 +111,8 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
       return;
     case kGcnArch:
       return;
+    case kApiVersion:
+      return;
   }
 }
 
index 2e637f5..6ed9bcc 100644 (file)
@@ -25,6 +25,7 @@
 #define TVM_RUNTIME_ROCM_ROCM_COMMON_H_
 
 #include <hip/hip_runtime_api.h>
+#include <hip/hip_version.h>
 #include <tvm/runtime/packed_func.h>
 
 #include <string>
index e3dbef5..e1a14c7 100644 (file)
@@ -103,13 +103,19 @@ class ROCMDeviceAPI final : public DeviceAPI {
         return;
       }
       case kMaxRegistersPerBlock:
-        return;
+        ROCM_CALL(
+            hipDeviceGetAttribute(&value, hipDeviceAttributeMaxRegistersPerBlock, ctx.device_id));
+        break;
       case kGcnArch: {
         hipDeviceProp_t prop;
         ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
         *rv = prop.gcnArch;
         return;
       }
+      case kApiVersion: {
+        *rv = HIP_VERSION;
+        return;
+      }
     }
     *rv = value;
   }
index ade4ddc..9e730b7 100644 (file)
@@ -417,6 +417,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
       return;
     case kGcnArch:
       return;
+    case kApiVersion:
+      return;
   }
 }
 
index 8e6b3a2..93c94cf 100644 (file)
@@ -108,11 +108,13 @@ class CodeGenAMDGPU : public CodeGenLLVM {
       llvm::GlobalVariable* global = new llvm::GlobalVariable(
           *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr,
           llvm::GlobalValue::NotThreadLocal, shared_address_space);
+      if (global->getAlignment() < static_cast<uint32_t>(info.alignment)) {
 #if TVM_LLVM_VERSION >= 100
-      global->setAlignment(llvm::Align(info.alignment));
+        global->setAlignment(llvm::Align(info.alignment));
 #else
-      global->setAlignment(info.alignment);
+        global->setAlignment(info.alignment);
 #endif
+      }
       buf = global;
     }
 
@@ -212,6 +214,20 @@ inline int DetectROCMComputeVersion(const std::string& target) {
   return 900;
 }
 
+inline int DetectROCMApiVersion() {
+  TVMContext tvm_ctx;
+  tvm_ctx.device_type = kDLROCM;
+  tvm_ctx.device_id = 0;
+  tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true);
+  if (api != nullptr) {
+    TVMRetValue val;
+    api->GetAttr(tvm_ctx, tvm::runtime::kApiVersion, &val);
+    return val.operator int();
+  }
+  LOG(WARNING) << "Cannot detect ROCm version, assume >= 3.5";
+  return 305;
+}
+
 runtime::Module BuildAMDGPU(IRModule mod, std::string target) {
 #if TVM_LLVM_VERSION < 90
   LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
@@ -221,8 +237,13 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) {
   InitializeLLVM();
   CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm");
   std::ostringstream config;
-  config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target)
-         << " -mattr=-code-object-v3 " << target.substr(4, target.length() - 4);
+  config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target);
+  if (DetectROCMApiVersion() < 305) {
+    // before ROCm 3.5 we needed code object v2, starting
+    // with 3.5 we need v3 (this argument disables v3)
+    config << " -mattr=-code-object-v3 ";
+  }
+  config << target.substr(4, target.length() - 4);
   std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
   // careful: cg will hold a naked pointer reference to ctx, so it should