[OpenMP] Replace pointer comparison with `isSharedMemPtr` check
authorJohannes Doerfert <johannes@jdoerfert.de>
Tue, 4 Oct 2022 12:50:45 +0000 (05:50 -0700)
committerJohannes Doerfert <johannes@jdoerfert.de>
Wed, 5 Oct 2022 02:24:22 +0000 (19:24 -0700)
The pointer comparison was causing confusion for capture tracking, let's
avoid confusion.

Differential Revision: https://reviews.llvm.org/D135160

openmp/libomptarget/DeviceRTL/include/Utils.h
openmp/libomptarget/DeviceRTL/src/State.cpp
openmp/libomptarget/DeviceRTL/src/Utils.cpp

index 9e71b51..178c001 100644 (file)
@@ -74,6 +74,9 @@ template <typename Ty1, typename Ty2> inline Ty1 align_down(Ty1 V, Ty2 Align) {
   return V - V % Align;
 }
 
+/// Return true iff \p Ptr is pointing into shared (local) memory (AS(3)).
+bool isSharedMemPtr(void *Ptr);
+
 /// A  pointer variable that has by design an `undef` value. Use with care.
 __attribute__((loader_uninitialized)) static void *const UndefPtr;
 
index 7a73330..59e6f48 100644 (file)
@@ -14,6 +14,7 @@
 #include "Interface.h"
 #include "Synchronization.h"
 #include "Types.h"
+#include "Utils.h"
 
 using namespace _OMP;
 
@@ -147,7 +148,7 @@ void *SharedMemorySmartStackTy::push(uint64_t Bytes) {
 
 void SharedMemorySmartStackTy::pop(void *Ptr, uint32_t Bytes) {
   uint64_t AlignedBytes = utils::align_up(Bytes, Alignment);
-  if (Ptr >= &Data[0] && Ptr < &Data[state::SharedScratchpadSize]) {
+  if (utils::isSharedMemPtr(Ptr)) {
     int TId = mapping::getThreadIdInBlock();
     Usage[TId] -= AlignedBytes;
     return;
index 453d131..2aa0194 100644 (file)
@@ -32,6 +32,7 @@ __attribute__((weak, optnone, cold)) KEEP_ALIVE void keepAlive() {
 
 namespace impl {
 
+bool isSharedMemPtr(const void *Ptr) { return false; }
 void Unpack(uint64_t Val, uint32_t *LowBits, uint32_t *HighBits);
 uint64_t Pack(uint32_t LowBits, uint32_t HighBits);
 
@@ -51,6 +52,7 @@ uint64_t Pack(uint32_t LowBits, uint32_t HighBits) {
 }
 
 #pragma omp end declare variant
+///}
 
 /// NVPTX Implementation
 ///
@@ -74,6 +76,7 @@ uint64_t Pack(uint32_t LowBits, uint32_t HighBits) {
 }
 
 #pragma omp end declare variant
+///}
 
 int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane);
 int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t LaneDelta,
@@ -99,6 +102,9 @@ int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t LaneDelta,
   return __builtin_amdgcn_ds_bpermute(Index << 2, Var);
 }
 
+bool isSharedMemPtr(const void * Ptr) {
+  return __builtin_amdgcn_is_shared((const __attribute__((address_space(0))) void *)Ptr);
+}
 #pragma omp end declare variant
 ///}
 
@@ -117,7 +123,10 @@ int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta, int32_t Width) {
   return __nvvm_shfl_sync_down_i32(Mask, Var, Delta, T);
 }
 
+bool isSharedMemPtr(const void *Ptr) { return __nvvm_isspacep_shared(Ptr); }
+
 #pragma omp end declare variant
+///}
 } // namespace impl
 
 uint64_t utils::pack(uint32_t LowBits, uint32_t HighBits) {
@@ -137,6 +146,8 @@ int32_t utils::shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta,
   return impl::shuffleDown(Mask, Var, Delta, Width);
 }
 
+bool utils::isSharedMemPtr(void *Ptr) { return impl::isSharedMemPtr(Ptr); }
+
 extern "C" {
 int32_t __kmpc_shuffle_int32(int32_t Val, int16_t Delta, int16_t SrcLane) {
   FunctionTracingRAII();