[NVPTX] Unforce minimum alignment of 4 for byval arguments of device-side functions.
authorPavel Kopyl <pavelkopyl@gmail.com>
Sat, 22 Apr 2023 00:52:04 +0000 (02:52 +0200)
committerPavel Kopyl <pavelkopyl@gmail.com>
Mon, 24 Apr 2023 22:18:16 +0000 (00:18 +0200)
Minimum alignment of 4 for byval arguments was forced to workaround
a bug in old versions of ptxas. Details: https://reviews.llvm.org/D22428.
Recent ptxas versions (> 9.0) do not seem to have this bug, so alignment
requirement was relaxed. To force again minimum alignment of 4, use
'-force-min-byval-param-align' option.

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
llvm/test/CodeGen/NVPTX/param-align.ll

index ec96e3c..aa34b2e 100644 (file)
@@ -89,6 +89,12 @@ static cl::opt<bool> UsePrecSqrtF32(
     cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
     cl::init(true));
 
+static cl::opt<bool> ForceMinByValParamAlign(
+    "nvptx-force-min-byval-param-align", cl::Hidden,
+    cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
+             " params of device functions."),
+    cl::init(false));
+
 int NVPTXTargetLowering::getDivF32Level() const {
   if (UsePrecDivF32.getNumOccurrences() > 0) {
     // If nvptx-prec-div32=N is used on the command-line, always honor it
@@ -4502,16 +4508,17 @@ Align NVPTXTargetLowering::getFunctionByValParamAlign(
   if (F)
     ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
 
-  // Work around a bug in ptxas. When PTX code takes address of
+  // Old ptx versions have a bug. When PTX code takes address of
   // byval parameter with alignment < 4, ptxas generates code to
   // spill argument into memory. Alas on sm_50+ ptxas generates
   // SASS code that fails with misaligned access. To work around
   // the problem, make sure that we align byval parameters by at
-  // least 4.
-  // TODO: this will need to be undone when we get to support multi-TU
-  // device-side compilation as it breaks ABI compatibility with nvcc.
-  // Hopefully ptxas bug is fixed by then.
-  ArgAlign = std::max(ArgAlign, Align(4));
+  // least 4. This bug seems to be fixed at least starting from
+  // ptxas > 9.0.
+  // TODO: remove this after verifying the bug is not reproduced
+  // on non-deprecated ptxas versions.
+  if (ForceMinByValParamAlign)
+    ArgAlign = std::max(ArgAlign, Align(4));
 
   return ArgAlign;
 }
index a743b14..0a411f4 100644 (file)
@@ -13,8 +13,9 @@ target triple = "nvptx64-nvidia-cuda"
 %"class.sycl::_V1::detail::half_impl::half" = type { half }
 %complex_half = type { half, half }
 
-; CHECK: .param .align 4 .b8 param2[4];
-; CHECK: st.param.v2.b16         [param2+0], {%h2, %h1};
+; CHECK: .param .align 2 .b8 param2[4];
+; CHECK: st.param.b16   [param2+0], %h1;
+; CHECK: st.param.b16   [param2+2], %h2;
 ; CHECK: .param .align 2 .b8 retval0[4];
 ; CHECK: call.uni (retval0),
 ; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,
@@ -29,15 +30,16 @@ entry:
 ;;
 declare ptr @usefp(ptr %fp)
 ; CHECK: .func callee(
-; CHECK-NEXT: .param .align 4 .b8 callee_param_0[4]
+; CHECK-NEXT: .param .align 2 .b8 callee_param_0[4]
 define internal void @callee(ptr byval(%"class.complex") %byval_arg) {
   ret void
 }
 define void @boom() {
   %fp = call ptr @usefp(ptr @callee)
-  ; CHECK: .param .align 4 .b8 param0[4];
-  ; CHECK: st.param.v2.b16 [param0+0]
-  ; CHECK: .callprototype ()_ (.param .align 4 .b8 _[4]);
+  ; CHECK: .param .align 2 .b8 param0[4];
+  ; CHECK: st.param.b16 [param0+0], %h1;
+  ; CHECK: st.param.b16 [param0+2], %h2;
+  ; CHECK: .callprototype ()_ (.param .align 2 .b8 _[4]);
   call void %fp(ptr byval(%"class.complex") null)
   ret void
 }
index f3d3003..5f5d77b 100644 (file)
@@ -1,5 +1,7 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s --check-prefixes=CHECK,NOALIGN4
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-force-min-byval-param-align | FileCheck %s --check-prefixes=CHECK,ALIGN4
 ; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 -nvptx-force-min-byval-param-align | %ptxas-verify %}
 
 ;;; Need 4-byte alignment on ptr passed byval
 define ptx_device void @t1(ptr byval(float) %x) {
@@ -25,20 +27,21 @@ define ptx_device void @t3(ptr byval(%struct.float2) %x) {
   ret void
 }
 
-;;; Need at least 4-byte alignment in order to avoid miscompilation by
-;;; ptxas for sm_50+
 define ptx_device void @t4(ptr byval(i8) %x) {
 ; CHECK: .func t4
-; CHECK: .param .align 4 .b8 t4_param_0[1]
+; NOALIGN4: .param .align 1 .b8 t4_param_0[1]
+; ALIGN4: .param .align 4 .b8 t4_param_0[1]
   ret void
 }
 
 ;;; Make sure we adjust alignment at the call site as well.
 define ptx_device void @t5(ptr align 2 byval(i8) %x) {
 ; CHECK: .func t5
-; CHECK: .param .align 4 .b8 t5_param_0[1]
+; NOALIGN4: .param .align 2 .b8 t5_param_0[1]
+; ALIGN4: .param .align 4 .b8 t5_param_0[1]
 ; CHECK: {
-; CHECK: .param .align 4 .b8 param0[1];
+; NOALIGN4: .param .align 1 .b8 param0[1];
+; ALIGN4:   .param .align 4 .b8 param0[1];
 ; CHECK: call.uni
   call void @t4(ptr byval(i8) %x)
   ret void
@@ -56,11 +59,13 @@ define ptx_device void @t6() {
   call void %fp(ptr byval(double) null);
 
   %fp2 = call ptr @getfp(i32 1)
-; CHECK: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]);
+; NOALIGN4: prototype_4 : .callprototype ()_ (.param .align 2 .b8 _[4]);
+; ALIGN4: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]);
   call void %fp(ptr byval(%struct.half2) null);
 
   %fp3 = call ptr @getfp(i32 2)
-; CHECK: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]);
+; NOALIGN4: prototype_6 : .callprototype ()_ (.param .align 1 .b8 _[1]);
+; ALIGN4: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]);
   call void %fp(ptr byval(i8) null);
   ret void
 }