[OpenMP] Add thread limit environment variable support to plugins
authorJoseph Huber <jhuber6@vols.utk.edu>
Tue, 8 Jun 2021 19:12:24 +0000 (15:12 -0400)
committerHuber, Joseph <huberjn@ornl.gov>
Tue, 22 Jun 2021 20:25:40 +0000 (16:25 -0400)
The OpenMP 5.1 standard defines the environment variable
`OMP_TEAMS_THREAD_LIMIT` to limit the number of threads that will be run in a
single block. This patch adds support for this into the AMDGPU and CUDA
plugins.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D103923

openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
openmp/libomptarget/plugins/cuda/src/rtl.cpp

index 8a4e12c..012bd1f 100644 (file)
@@ -405,6 +405,7 @@ public:
   // OpenMP Environment properties
   int EnvNumTeams;
   int EnvTeamLimit;
+  int EnvTeamThreadLimit;
   int EnvMaxTeamsDefault;
 
   // OpenMP Requires Flags
@@ -645,6 +646,13 @@ public:
     } else {
       EnvMaxTeamsDefault = -1;
     }
+    envStr = getenv("OMP_TEAMS_THREAD_LIMIT");
+    if (envStr) {
+      EnvTeamThreadLimit = std::stoi(envStr);
+      DP("Parsed OMP_TEAMS_THREAD_LIMIT=%d\n", EnvTeamThreadLimit);
+    } else {
+      EnvTeamThreadLimit = -1;
+    }
 
     // Default state.
     RequiresFlags = OMP_REQ_UNDEFINED;
@@ -950,6 +958,14 @@ int32_t __tgt_rtl_init_device(int device_id) {
        DeviceInfo.GroupsPerDevice[device_id]);
   }
 
+  // Adjust threads to the env variables
+  if (DeviceInfo.EnvTeamThreadLimit > 0 &&
+      (enforce_upper_bound(&DeviceInfo.NumThreads[device_id],
+                           DeviceInfo.EnvTeamThreadLimit))) {
+    DP("Capping max number of threads to OMP_TEAMS_THREAD_LIMIT=%d\n",
+       DeviceInfo.EnvTeamThreadLimit);
+  }
+
   // Set default number of threads
   DeviceInfo.NumThreads[device_id] = RTLDeviceInfoTy::Default_WG_Size;
   DP("Default number of threads set according to library's default %d\n",
index e8fe637..7b04bc9 100644 (file)
@@ -281,6 +281,7 @@ class DeviceRTLTy {
   // OpenMP environment properties
   int EnvNumTeams;
   int EnvTeamLimit;
+  int EnvTeamThreadLimit;
   // OpenMP requires flags
   int64_t RequiresFlags;
 
@@ -436,7 +437,7 @@ public:
 
   DeviceRTLTy()
       : NumberOfDevices(0), EnvNumTeams(-1), EnvTeamLimit(-1),
-        RequiresFlags(OMP_REQ_UNDEFINED) {
+        EnvTeamThreadLimit(-1), RequiresFlags(OMP_REQ_UNDEFINED) {
 
     DP("Start initializing CUDA\n");
 
@@ -467,6 +468,11 @@ public:
       EnvTeamLimit = std::stoi(EnvStr);
       DP("Parsed OMP_TEAM_LIMIT=%d\n", EnvTeamLimit);
     }
+    if (const char *EnvStr = getenv("OMP_TEAMS_THREAD_LIMIT")) {
+      // OMP_TEAMS_THREAD_LIMIT has been set
+      EnvTeamThreadLimit = std::stoi(EnvStr);
+      DP("Parsed OMP_TEAMS_THREAD_LIMIT=%d\n", EnvTeamThreadLimit);
+    }
     if (const char *EnvStr = getenv("OMP_NUM_TEAMS")) {
       // OMP_NUM_TEAMS has been set
       EnvNumTeams = std::stoi(EnvStr);
@@ -596,14 +602,23 @@ public:
       DP("Error getting max block dimension, use default value %d\n",
          DeviceRTLTy::DefaultNumThreads);
       DeviceData[DeviceId].ThreadsPerBlock = DeviceRTLTy::DefaultNumThreads;
-    } else if (MaxBlockDimX <= DeviceRTLTy::HardThreadLimit) {
+    } else {
       DP("Using %d CUDA threads per block\n", MaxBlockDimX);
       DeviceData[DeviceId].ThreadsPerBlock = MaxBlockDimX;
-    } else {
-      DP("Max CUDA threads per block %d exceeds the hard thread limit %d, "
-         "capping at the hard limit\n",
-         MaxBlockDimX, DeviceRTLTy::HardThreadLimit);
-      DeviceData[DeviceId].ThreadsPerBlock = DeviceRTLTy::HardThreadLimit;
+
+      if (EnvTeamThreadLimit > 0 &&
+          DeviceData[DeviceId].ThreadsPerBlock > EnvTeamThreadLimit) {
+        DP("Max CUDA threads per block %d exceeds the thread limit %d set by "
+           "OMP_TEAMS_THREAD_LIMIT, capping at the limit\n",
+           DeviceData[DeviceId].ThreadsPerBlock, EnvTeamThreadLimit);
+        DeviceData[DeviceId].ThreadsPerBlock = EnvTeamThreadLimit;
+      }
+      if (DeviceData[DeviceId].ThreadsPerBlock > DeviceRTLTy::HardThreadLimit) {
+        DP("Max CUDA threads per block %d exceeds the hard thread limit %d, "
+           "capping at the hard limit\n",
+           DeviceData[DeviceId].ThreadsPerBlock, DeviceRTLTy::HardThreadLimit);
+        DeviceData[DeviceId].ThreadsPerBlock = DeviceRTLTy::HardThreadLimit;
+      }
     }
 
     // Get and set warp size