ROCm: use GcnArch for mcpu and ApiVersion to select code object version (#6447)
authorThomas Viehmann <tv.code@beamnet.de>
Fri, 11 Sep 2020 06:11:44 +0000 (08:11 +0200)
committerGitHub <noreply@github.com>
Fri, 11 Sep 2020 06:11:44 +0000 (15:11 +0900)
src/target/target_kind.cc

index efb9d16..8cb6786 100644 (file)
@@ -172,19 +172,28 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
     arch = ExtractIntWithPrefix(mcpu, "gfx");
     CHECK(arch != -1) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
   } else {
-    TVMRetValue version;
-    if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &version)) {
-      LOG(WARNING) << "Unable to detect ROCm version, default to \"-mcpu=gfx305\" instead";
-      arch = 305;
+    TVMRetValue val;
+    if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kGcnArch, &val)) {
+      LOG(WARNING) << "Unable to detect ROCm compute arch, default to \"-mcpu=gfx900\" instead";
+      arch = 900;
     } else {
-      arch = version.operator int();
+      arch = val.operator int();
     }
     attrs.Set("mcpu", String("gfx") + std::to_string(arch));
   }
   // Update -mattr before ROCm 3.5:
   //   Before ROCm 3.5 we needed code object v2, starting
   //   with 3.5 we need v3 (this argument disables v3)
-  if (arch < 305) {
+
+  TVMRetValue val;
+  int version;
+  if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &val)) {
+    LOG(WARNING) << "Unable to detect ROCm version, assuming >= 3.5";
+    version = 305;
+  } else {
+    version = val.operator int();
+  }
+  if (version < 305) {
     Array<String> mattr;
     if (attrs.count("mattr")) {
       mattr = Downcast<Array<String>>(attrs.at("mattr"));