[NVPTX] Add initial support for '.alias' in PTX
authorJoseph Huber <jhuber6@vols.utk.edu>
Wed, 12 Jul 2023 18:35:01 +0000 (13:35 -0500)
committerJoseph Huber <jhuber6@vols.utk.edu>
Fri, 21 Jul 2023 21:43:46 +0000 (16:43 -0500)
This patch adds initial support for using aliases when targeting PTX. We
perform a pretty strict conversion from the globals referenced to the
expected output. as described in the PTX documentation at
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#kernel-and-function-directives-alias

These cannot currently be used due to a bug in the `nvlink`
implementation that causes aliases to pruned functions to crash the
linker.

Reviewed By: tra

Differential Revision: https://reviews.llvm.org/D155211

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
llvm/test/CodeGen/NVPTX/alias-errors.ll [new file with mode: 0644]
llvm/test/CodeGen/NVPTX/alias.ll

index fd03267..e9239fa 100644 (file)
@@ -473,6 +473,7 @@ void NVPTXAsmPrinter::emitFunctionEntryLabel() {
   CurrentFnSym->print(O, MAI);
 
   emitFunctionParamList(F, O);
+  O << "\n";
 
   if (isKernelFunction(*F))
     emitKernelFunctionDirectives(*F, O);
@@ -623,6 +624,7 @@ void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
   getSymbol(F)->print(O, MAI);
   O << "\n";
   emitFunctionParamList(F, O);
+  O << "\n";
   if (shouldEmitPTXNoReturn(F, TM))
     O << ".noreturn";
   O << ";\n";
@@ -790,10 +792,12 @@ void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
 }
 
 bool NVPTXAsmPrinter::doInitialization(Module &M) {
-  if (M.alias_size()) {
-    report_fatal_error("Module has aliases, which NVPTX does not support.");
-    return true; // error
-  }
+  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
+  const NVPTXSubtarget &STI =
+      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
+  if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
+    report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");
+
   if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
       !LowerCtorDtor) {
     report_fatal_error(
@@ -850,6 +854,32 @@ void NVPTXAsmPrinter::emitGlobals(const Module &M) {
   OutStreamer->emitRawText(OS2.str());
 }
 
+void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
+  SmallString<128> Str;
+  raw_svector_ostream OS(Str);
+
+  MCSymbol *Name = getSymbol(&GA);
+  const Function *F = dyn_cast<Function>(GA.getAliasee());
+  if (!F || isKernelFunction(*F))
+    report_fatal_error("NVPTX aliasee must be a non-kernel function");
+
+  if (GA.hasLinkOnceLinkage() || GA.hasWeakLinkage() ||
+      GA.hasAvailableExternallyLinkage() || GA.hasCommonLinkage())
+    report_fatal_error("NVPTX aliasee must not be '.weak'");
+
+  OS << "\n";
+  emitLinkageDirective(F, OS);
+  OS << ".func ";
+  printReturnValStr(F, OS);
+  OS << Name->getName();
+  emitFunctionParamList(F, OS);
+  OS << ";\n";
+
+  OS << ".alias " << Name->getName() << ", " << F->getName() << ";\n";
+
+  OutStreamer->emitRawText(OS.str());
+}
+
 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
                                  const NVPTXSubtarget &STI) {
   O << "//\n";
@@ -906,6 +936,16 @@ bool NVPTXAsmPrinter::doFinalization(Module &M) {
     GlobalsEmitted = true;
   }
 
+  // If we have any aliases we emit them at the end.
+  SmallVector<GlobalAlias *> AliasesToRemove;
+  for (GlobalAlias &Alias : M.aliases()) {
+    emitGlobalAlias(M, Alias);
+    AliasesToRemove.push_back(&Alias);
+  }
+
+  for (GlobalAlias *A : AliasesToRemove)
+    A->eraseFromParent();
+
   // call doFinalization
   bool ret = AsmPrinter::doFinalization(M);
 
@@ -1465,7 +1505,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
   bool hasImageHandles = STI.hasImageHandles();
 
   if (F->arg_empty() && !F->isVarArg()) {
-    O << "()\n";
+    O << "()";
     return;
   }
 
@@ -1659,7 +1699,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
     O << TLI->getParamName(F, /* vararg */ -1) << "[]";
   }
 
-  O << "\n)\n";
+  O << "\n)";
 }
 
 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
index 673aad1..2bd4011 100644 (file)
@@ -174,6 +174,7 @@ private:
   void printModuleLevelGV(const GlobalVariable *GVar, raw_ostream &O,
                           bool processDemoted, const NVPTXSubtarget &STI);
   void emitGlobals(const Module &M);
+  void emitGlobalAlias(const Module &M, const GlobalAlias &GA);
   void emitHeader(Module &M, raw_ostream &O, const NVPTXSubtarget &STI);
   void emitKernelFunctionDirectives(const Function &F, raw_ostream &O) const;
   void emitVirtualRegister(unsigned int vr, raw_ostream &);
diff --git a/llvm/test/CodeGen/NVPTX/alias-errors.ll b/llvm/test/CodeGen/NVPTX/alias-errors.ll
new file mode 100644 (file)
index 0000000..0db3b3a
--- /dev/null
@@ -0,0 +1,9 @@
+; RUN: not --crash llc < %s -march=nvptx64 -mcpu=sm_30 -mattr=+ptx43 2>&1 | FileCheck %s --check-prefix=ATTR
+; RUN: not --crash llc < %s -march=nvptx64 -mcpu=sm_20 -mattr=+ptx63 2>&1 | FileCheck %s --check-prefix=ATTR
+; RUN: not --crash llc < %s -march=nvptx64 -mcpu=sm_30 -mattr=+ptx63 2>&1 | FileCheck %s --check-prefix=ALIAS
+
+; ATTR: .alias requires PTX version >= 6.3 and sm_30
+
+; ALIAS: NVPTX aliasee must be a non-kernel function
+@a = global i32 42, align 8
+@b = internal alias i32, ptr @a
index 6124a7c..3c23133 100644 (file)
@@ -1,7 +1,27 @@
-; RUN: not --crash llc < %s -march=nvptx -mcpu=sm_20 2>&1 | FileCheck %s
-
-; Check that llc dies gracefully when given an alias.
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_30 -mattr=+ptx63 | FileCheck %s
 
 define i32 @a() { ret i32 0 }
-; CHECK: ERROR: Module has aliases
 @b = internal alias i32 (), ptr @a
+@c = internal alias i32 (), ptr @a
+
+define void @foo(i32 %0, ptr %1) { ret void }
+@bar = alias i32 (), ptr @foo
+
+; CHECK: .visible .func  (.param .b32 func_retval0) a()
+
+;      CHECK: .visible .func foo(
+; CHECK-NEXT:         .param .b32 foo_param_0,
+; CHECK-NEXT:         .param .b64 foo_param_1
+; CHECK-NEXT: )
+
+;      CHECK: .visible .func  (.param .b32 func_retval0) b();
+; CHECK-NEXT: .alias b, a;
+
+;      CHECK: .visible .func  (.param .b32 func_retval0) c();
+; CHECK-NEXT: .alias c, a;
+
+;      CHECK: .visible .func bar(
+; CHECK-NEXT:         .param .b32 foo_param_0,
+; CHECK-NEXT:         .param .b64 foo_param_1
+; CHECK-NEXT: );
+; CHECK-NEXT: .alias bar, foo;