From ec55dcedcec1cdf95d020307067bc871cb2b70e4 Mon Sep 17 00:00:00 2001 From: Matt Arsenault Date: Thu, 16 Sep 2021 17:49:50 -0400 Subject: [PATCH] AMDGPU: Refactor getWavesPerEU to separate flat workgroup size query Add an overload to pass the flat workgroup range in separately. This will allow the attributor to use the assumed value for amdgpu-flat-workgroup-sizes when inferring amdgpu-waves-per-eu. --- llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp | 5 +---- llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h | 13 ++++++++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp b/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp index 0094827..1873057 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp @@ -533,13 +533,10 @@ std::pair AMDGPUSubtarget::getFlatWorkGroupSizes( } std::pair AMDGPUSubtarget::getWavesPerEU( - const Function &F) const { + const Function &F, std::pair FlatWorkGroupSizes) const { // Default minimum/maximum number of waves per execution unit. std::pair Default(1, getMaxWavesPerEU()); - // Default/requested minimum/maximum flat work group sizes. - std::pair FlatWorkGroupSizes = getFlatWorkGroupSizes(F); - // If minimum/maximum flat work group sizes were explicitly requested using // "amdgpu-flat-work-group-size" attribute, then set default minimum/maximum // number of waves per execution unit to values implied by requested diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h b/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h index b160cdf..1d8a9e6 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h @@ -91,7 +91,18 @@ public: /// be converted to integer, violate subtarget's specifications, or are not /// compatible with minimum/maximum number of waves limited by flat work group /// size, register usage, and/or lds usage. - std::pair getWavesPerEU(const Function &F) const; + std::pair getWavesPerEU(const Function &F) const { + // Default/requested minimum/maximum flat work group sizes. + std::pair FlatWorkGroupSizes = getFlatWorkGroupSizes(F); + return getWavesPerEU(F, FlatWorkGroupSizes); + } + + /// Overload which uses the specified values for the flat work group sizes, + /// rather than querying the function itself. \p FlatWorkGroupSizes Should + /// correspond to the function's value for getFlatWorkGroupSizes. + std::pair + getWavesPerEU(const Function &F, + std::pair FlatWorkGroupSizes) const; /// Return the amount of LDS that can be used that will not restrict the /// occupancy lower than WaveCount. -- 2.7.4