[OpenMP][CUDA] Fix the issue that P2P memcpy doesn't work
authorShilei Tian <i@tianshilei.me>
Tue, 28 Jun 2022 19:31:41 +0000 (15:31 -0400)
committerShilei Tian <i@tianshilei.me>
Tue, 28 Jun 2022 19:32:03 +0000 (15:32 -0400)
This patch fixes the issue that P2P memcpy doesn't work. The root cause is we didn't set current context when calling the API function. In addition, a matrix to track the states of each pair of devices is also added such that we only need to query and configure the device once.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D122764

openmp/libomptarget/plugins/cuda/src/rtl.cpp

index d6cf166..99739f4 100644 (file)
@@ -355,6 +355,10 @@ class DeviceRTLTy {
   /// devices.
   std::vector<bool> InitializedFlags;
 
+  enum class PeerAccessState : uint8_t { Unkown, Yes, No };
+  std::vector<std::vector<PeerAccessState>> PeerAccessMatrix;
+  std::mutex PeerAccessMatrixLock;
+
   /// A class responsible for interacting with device native runtime library to
   /// allocate and free memory.
   class CUDADeviceAllocatorTy : public DeviceAllocatorTy {
@@ -520,6 +524,9 @@ public:
     Modules.resize(NumberOfDevices);
     StreamPool.resize(NumberOfDevices);
     EventPool.resize(NumberOfDevices);
+    PeerAccessMatrix.resize(NumberOfDevices);
+    for (auto &V : PeerAccessMatrix)
+      V.resize(NumberOfDevices, PeerAccessState::Unkown);
 
     // Get environment variables regarding teams
     if (const char *EnvStr = getenv("OMP_TEAM_LIMIT")) {
@@ -1015,7 +1022,7 @@ public:
   }
 
   int dataExchange(int SrcDevId, const void *SrcPtr, int DstDevId, void *DstPtr,
-                   int64_t Size, __tgt_async_info *AsyncInfo) const {
+                   int64_t Size, __tgt_async_info *AsyncInfo) {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
     CUresult Err;
@@ -1023,40 +1030,69 @@ public:
 
     // If they are two devices, we try peer to peer copy first
     if (SrcDevId != DstDevId) {
-      int CanAccessPeer = 0;
-      Err = cuDeviceCanAccessPeer(&CanAccessPeer, SrcDevId, DstDevId);
-      if (Err != CUDA_SUCCESS) {
-        REPORT("Error returned from cuDeviceCanAccessPeer. src = %" PRId32
-               ", dst = %" PRId32 "\n",
+      std::lock_guard<std::mutex> LG(PeerAccessMatrixLock);
+
+      switch (PeerAccessMatrix[SrcDevId][DstDevId]) {
+      case PeerAccessState::No: {
+        REPORT("Peer access from %" PRId32 " to %" PRId32
+               " is not supported. Fall back to D2D memcpy.\n",
                SrcDevId, DstDevId);
-        CUDA_ERR_STRING(Err);
         return memcpyDtoD(SrcPtr, DstPtr, Size, Stream);
       }
+      case PeerAccessState::Unkown: {
+        int CanAccessPeer = 0;
+        Err = cuDeviceCanAccessPeer(&CanAccessPeer, SrcDevId, DstDevId);
+        if (Err != CUDA_SUCCESS) {
+          REPORT("Error returned from cuDeviceCanAccessPeer. src = %" PRId32
+                 ", dst = %" PRId32 ". Fall back to D2D memcpy.\n",
+                 SrcDevId, DstDevId);
+          CUDA_ERR_STRING(Err);
+          PeerAccessMatrix[SrcDevId][DstDevId] = PeerAccessState::No;
+          return memcpyDtoD(SrcPtr, DstPtr, Size, Stream);
+        }
 
-      if (!CanAccessPeer) {
-        DP("P2P memcpy not supported so fall back to D2D memcpy");
-        return memcpyDtoD(SrcPtr, DstPtr, Size, Stream);
-      }
+        if (!CanAccessPeer) {
+          REPORT("P2P access from %d to %d is not supported. Fall back to D2D "
+                 "memcpy.\n",
+                 SrcDevId, DstDevId);
+          PeerAccessMatrix[SrcDevId][DstDevId] = PeerAccessState::No;
+          return memcpyDtoD(SrcPtr, DstPtr, Size, Stream);
+        }
 
-      Err = cuCtxEnablePeerAccess(DeviceData[DstDevId].Context, 0);
-      if (Err != CUDA_SUCCESS) {
-        REPORT("Error returned from cuCtxEnablePeerAccess. src = %" PRId32
-               ", dst = %" PRId32 "\n",
-               SrcDevId, DstDevId);
+        Err = cuCtxEnablePeerAccess(DeviceData[DstDevId].Context, 0);
+        if (Err != CUDA_SUCCESS) {
+          REPORT("Error returned from cuCtxEnablePeerAccess. src = %" PRId32
+                 ", dst = %" PRId32 ". Fall back to D2D memcpy.\n",
+                 SrcDevId, DstDevId);
+          CUDA_ERR_STRING(Err);
+          PeerAccessMatrix[SrcDevId][DstDevId] = PeerAccessState::No;
+          return memcpyDtoD(SrcPtr, DstPtr, Size, Stream);
+        }
+
+        PeerAccessMatrix[SrcDevId][DstDevId] = PeerAccessState::Yes;
+
+        LLVM_FALLTHROUGH;
+      }
+      case PeerAccessState::Yes: {
+        Err = cuMemcpyPeerAsync(
+            (CUdeviceptr)DstPtr, DeviceData[DstDevId].Context,
+            (CUdeviceptr)SrcPtr, DeviceData[SrcDevId].Context, Size, Stream);
+        if (Err == CUDA_SUCCESS)
+          return OFFLOAD_SUCCESS;
+
+        DP("Error returned from cuMemcpyPeerAsync. src_ptr = " DPxMOD
+           ", src_id =%" PRId32 ", dst_ptr = " DPxMOD ", dst_id =%" PRId32
+           ". Fall back to D2D memcpy.\n",
+           DPxPTR(SrcPtr), SrcDevId, DPxPTR(DstPtr), DstDevId);
         CUDA_ERR_STRING(Err);
+
         return memcpyDtoD(SrcPtr, DstPtr, Size, Stream);
       }
-
-      Err = cuMemcpyPeerAsync((CUdeviceptr)DstPtr, DeviceData[DstDevId].Context,
-                              (CUdeviceptr)SrcPtr, DeviceData[SrcDevId].Context,
-                              Size, Stream);
-      if (Err == CUDA_SUCCESS)
-        return OFFLOAD_SUCCESS;
-
-      DP("Error returned from cuMemcpyPeerAsync. src_ptr = " DPxMOD
-         ", src_id =%" PRId32 ", dst_ptr = " DPxMOD ", dst_id =%" PRId32 "\n",
-         DPxPTR(SrcPtr), SrcDevId, DPxPTR(DstPtr), DstDevId);
-      CUDA_ERR_STRING(Err);
+      default:
+        REPORT("Unknown PeerAccessState %d.\n",
+               int(PeerAccessMatrix[SrcDevId][DstDevId]));
+        return OFFLOAD_FAIL;
+      }
     }
 
     return memcpyDtoD(SrcPtr, DstPtr, Size, Stream);
@@ -1598,8 +1634,10 @@ int32_t __tgt_rtl_data_exchange_async(int32_t src_dev_id, void *src_ptr,
   assert(DeviceRTL.isValidDeviceId(src_dev_id) && "src_dev_id is invalid");
   assert(DeviceRTL.isValidDeviceId(dst_dev_id) && "dst_dev_id is invalid");
   assert(AsyncInfo && "AsyncInfo is nullptr");
-  // NOTE: We don't need to set context for data exchange as the device contexts
-  // are passed to CUDA function directly.
+
+  if (DeviceRTL.setContext(src_dev_id) != OFFLOAD_SUCCESS)
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.dataExchange(src_dev_id, src_ptr, dst_dev_id, dst_ptr, size,
                                 AsyncInfo);
 }