[libomptarget][amdgpu] Skip device_State allocation when using bss global
authorJon Chesterfield <jonathanchesterfield@gmail.com>
Sun, 6 Dec 2020 12:13:56 +0000 (12:13 +0000)
committerJon Chesterfield <jonathanchesterfield@gmail.com>
Sun, 6 Dec 2020 12:13:56 +0000 (12:13 +0000)
openmp/libomptarget/plugins/amdgpu/src/rtl.cpp

index ea8770e..e688ef7 100644 (file)
@@ -1033,54 +1033,64 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
 
   DP("ATMI module successfully loaded!\n");
 
-  // Zero the pseudo-bss variable by calling into hsa
-  // Do this post-load to handle got
-  uint64_t device_State_bytes =
-      get_device_State_bytes((char *)image->ImageStart, img_size);
-  auto &dss = DeviceInfo.deviceStateStore[device_id];
-  if (device_State_bytes != 0) {
-
-    if (dss.first.get() == nullptr) {
-      assert(dss.second == 0);
-      void *ptr = NULL;
-      atmi_status_t err =
-          atmi_calloc(&ptr, device_State_bytes, get_gpu_mem_place(device_id));
-      if (err != ATMI_STATUS_SUCCESS) {
-        fprintf(stderr, "Failed to allocate device_state array\n");
-        return NULL;
-      }
-      dss = {std::unique_ptr<void, RTLDeviceInfoTy::atmiFreePtrDeletor>{ptr},
-             device_State_bytes};
-    }
-
-    void *ptr = dss.first.get();
-    if (device_State_bytes != dss.second) {
-      fprintf(stderr, "Inconsistent sizes of device_State unsupported\n");
-      exit(1);
-    }
+  {
+    // the device_State array is either large value in bss or a void* that
+    // needs to be assigned to a pointer to an array of size device_state_bytes
 
     void *state_ptr;
     uint32_t state_ptr_size;
-    err = atmi_interop_hsa_get_symbol_info(get_gpu_mem_place(device_id),
-                                           "omptarget_nvptx_device_State",
-                                           &state_ptr, &state_ptr_size);
+    atmi_status_t err = atmi_interop_hsa_get_symbol_info(
+        get_gpu_mem_place(device_id), "omptarget_nvptx_device_State",
+        &state_ptr, &state_ptr_size);
 
     if (err != ATMI_STATUS_SUCCESS) {
-      fprintf(stderr, "failed to find device_state ptr\n");
+      fprintf(stderr, "failed to find device_state symbol\n");
       return NULL;
     }
-    if (state_ptr_size != sizeof(void *)) {
+
+    if (state_ptr_size < sizeof(void *)) {
       fprintf(stderr, "unexpected size of state_ptr %u != %zu\n",
               state_ptr_size, sizeof(void *));
       return NULL;
     }
 
-    // write ptr to device memory so it can be used by later kernels
-    err = DeviceInfo.freesignalpool_memcpy_h2d(state_ptr, &ptr, sizeof(void *),
-                                               device_id);
-    if (err != ATMI_STATUS_SUCCESS) {
-      fprintf(stderr, "memcpy install of state_ptr failed\n");
-      return NULL;
+    // if it's larger than a void*, assume it's a bss array and no further
+    // initialization is required. Only try to set up a pointer for
+    // sizeof(void*)
+    if (state_ptr_size == sizeof(void *)) {
+      uint64_t device_State_bytes =
+          get_device_State_bytes((char *)image->ImageStart, img_size);
+      if (device_State_bytes == 0) {
+        return NULL;
+      }
+
+      auto &dss = DeviceInfo.deviceStateStore[device_id];
+      if (dss.first.get() == nullptr) {
+        assert(dss.second == 0);
+        void *ptr = NULL;
+        atmi_status_t err =
+            atmi_calloc(&ptr, device_State_bytes, get_gpu_mem_place(device_id));
+        if (err != ATMI_STATUS_SUCCESS) {
+          fprintf(stderr, "Failed to allocate device_state array\n");
+          return NULL;
+        }
+        dss = {std::unique_ptr<void, RTLDeviceInfoTy::atmiFreePtrDeletor>{ptr},
+               device_State_bytes};
+      }
+
+      void *ptr = dss.first.get();
+      if (device_State_bytes != dss.second) {
+        fprintf(stderr, "Inconsistent sizes of device_State unsupported\n");
+        exit(1);
+      }
+
+      // write ptr to device memory so it can be used by later kernels
+      err = DeviceInfo.freesignalpool_memcpy_h2d(state_ptr, &ptr,
+                                                 sizeof(void *), device_id);
+      if (err != ATMI_STATUS_SUCCESS) {
+        fprintf(stderr, "memcpy install of state_ptr failed\n");
+        return NULL;
+      }
     }
   }