[NVPTX] Enforce minumum alignment of 4 for byval parametrs in a function prototype
authorPavel Kopyl <pavelkopyl@gmail.com>
Mon, 9 Jan 2023 14:59:47 +0000 (17:59 +0300)
committerAndrew Savonichev <andrew.savonichev@gmail.com>
Tue, 10 Jan 2023 12:22:40 +0000 (15:22 +0300)
As a result, we have identical alignment calculation of byval
parameters for:

  - LowerCall() - getting alignment of an argument (.param)

  - emitFunctionParamList() - getting alignment of a
    parameter (.param) in a function declaration

  - getPrototype() - getting alignment of a parameter (.param) in a
    function prototypes that is used for indirect calls

This change is required to avoid ptxas error: 'Alignment of argument
does not match formal parameter'. This error happens even in cases
where it logically shouldn't.

For instance:

  .param .align 4 .b8 param0[4];
  ...
  callprototype ()_ (.param .align 2 .b8 _[4]);
  ...

Here we allocate 'param0' with alignment of 4 and it should be fine to
pass it to a function that requires minimum alignment of 2.

At least ptxas v12.0 rejects this code.

Differential Revision: https://reviews.llvm.org/D140581

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.h
llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
llvm/test/CodeGen/NVPTX/param-align.ll

index eab1954..dbf4bf4 100644 (file)
@@ -1612,21 +1612,12 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       // <a>  = optimal alignment for the element type; always multiple of
       //        PAL.getParamAlignment
       // size = typeallocsize of element type
-      Align OptimalAlign = getOptimalAlignForParam(ETy);
-
-      // Work around a bug in ptxas. When PTX code takes address of
-      // byval parameter with alignment < 4, ptxas generates code to
-      // spill argument into memory. Alas on sm_50+ ptxas generates
-      // SASS code that fails with misaligned access. To work around
-      // the problem, make sure that we align byval parameters by at
-      // least 4. Matching change must be made in LowerCall() where we
-      // prepare parameters for the call.
-      //
-      // TODO: this will need to be undone when we get to support multi-TU
-      // device-side compilation as it breaks ABI compatibility with nvcc.
-      // Hopefully ptxas bug is fixed by then.
-      if (!isKernelFunc && OptimalAlign < Align(4))
-        OptimalAlign = Align(4);
+      Align OptimalAlign =
+          isKernelFunc
+              ? getOptimalAlignForParam(ETy)
+              : TLI->getFunctionByValParamAlign(
+                    F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
+
       unsigned sz = DL.getTypeAllocSize(ETy);
       O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
       printParamName(I, paramIndex, O);
index b7e81ea..6206670 100644 (file)
@@ -1414,13 +1414,10 @@ std::string NVPTXTargetLowering::getPrototype(
       continue;
     }
 
-    Align ParamByValAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-
-    // Try to increase alignment. This code matches logic in LowerCall when
-    // alignment increase is performed to increase vectorization options.
     Type *ETy = Args[i].IndirectType;
-    Align AlignCandidate = getFunctionParamOptimizedAlign(F, ETy, DL);
-    ParamByValAlign = std::max(ParamByValAlign, AlignCandidate);
+    Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+    Align ParamByValAlign =
+        getFunctionByValParamAlign(F, ETy, InitialAlign, DL);
 
     O << ".param .align " << ParamByValAlign.value() << " .b8 ";
     O << "_";
@@ -1560,17 +1557,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
       // so we don't need to worry whether it's naturally aligned or not.
       // See TargetLowering::LowerCallTo().
-      ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-
-      // Try to increase alignment to enhance vectorization options.
-      if (const Function *DirectCallee = CB->getCalledFunction())
-        ArgAlign = std::max(
-            ArgAlign, getFunctionParamOptimizedAlign(DirectCallee, ETy, DL));
-
-      // Enforce minumum alignment of 4 to work around ptxas miscompile
-      // for sm_50+. See corresponding alignment adjustment in
-      // emitFunctionParamList() for details.
-      ArgAlign = std::max(ArgAlign, Align(4));
+      Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+      ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
+                                            InitialAlign, DL);
       if (IsVAArg)
         VAOffset = alignTo(VAOffset, ArgAlign);
     } else {
@@ -4510,6 +4499,29 @@ Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
   return Align(std::max(uint64_t(16), ABITypeAlign));
 }
 
+/// Helper for computing alignment of a device function byval parameter.
+Align NVPTXTargetLowering::getFunctionByValParamAlign(
+    const Function *F, Type *ArgTy, Align InitialAlign,
+    const DataLayout &DL) const {
+  Align ArgAlign = InitialAlign;
+  // Try to increase alignment to enhance vectorization options.
+  if (F)
+    ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
+
+  // Work around a bug in ptxas. When PTX code takes address of
+  // byval parameter with alignment < 4, ptxas generates code to
+  // spill argument into memory. Alas on sm_50+ ptxas generates
+  // SASS code that fails with misaligned access. To work around
+  // the problem, make sure that we align byval parameters by at
+  // least 4.
+  // TODO: this will need to be undone when we get to support multi-TU
+  // device-side compilation as it breaks ABI compatibility with nvcc.
+  // Hopefully ptxas bug is fixed by then.
+  ArgAlign = std::max(ArgAlign, Align(4));
+
+  return ArgAlign;
+}
+
 /// isLegalAddressingMode - Return true if the addressing mode represented
 /// by AM is legal for this target, for a load/store of the specified type.
 /// Used to guide target specific optimizations, like loop strength reduction
index 78d8231..f48ec17 100644 (file)
@@ -461,6 +461,11 @@ public:
   Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
                                        const DataLayout &DL) const;
 
+  /// Helper for computing alignment of a device function byval parameter.
+  Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
+                                   Align InitialAlign,
+                                   const DataLayout &DL) const;
+
   /// isLegalAddressingMode - Return true if the addressing mode represented
   /// by AM is legal for this target, for a load/store of the specified type
   /// Used to guide target specific optimizations, like loop strength
index 7a0d78c..a743b14 100644 (file)
@@ -37,9 +37,9 @@ define void @boom() {
   %fp = call ptr @usefp(ptr @callee)
   ; CHECK: .param .align 4 .b8 param0[4];
   ; CHECK: st.param.v2.b16 [param0+0]
-  ; CHECK: .callprototype ()_ (.param .align 2 .b8 _[4]);
+  ; CHECK: .callprototype ()_ (.param .align 4 .b8 _[4]);
   call void %fp(ptr byval(%"class.complex") null)
   ret void
 }
 
-declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE()
+declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE(i32, i32, ptr byval(%"class.complex"))
index 022a750..f3d3003 100644 (file)
@@ -43,3 +43,24 @@ define ptx_device void @t5(ptr align 2 byval(i8) %x) {
   call void @t4(ptr byval(i8) %x)
   ret void
 }
+
+;;; Make sure we adjust alignment for a function prototype
+;;; in case of an inderect call.
+
+declare ptr @getfp(i32 %n)
+%struct.half2 = type { half, half }
+define ptx_device void @t6() {
+; CHECK: .func t6
+  %fp = call ptr @getfp(i32 0)
+; CHECK: prototype_2 : .callprototype ()_ (.param .align 8 .b8 _[8]);
+  call void %fp(ptr byval(double) null);
+
+  %fp2 = call ptr @getfp(i32 1)
+; CHECK: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]);
+  call void %fp(ptr byval(%struct.half2) null);
+
+  %fp3 = call ptr @getfp(i32 2)
+; CHECK: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]);
+  call void %fp(ptr byval(i8) null);
+  ret void
+}