[llvm][nvptx] add atomicity to counter in ISelLowering
authorTres Popp <tpopp@google.com>
Fri, 15 Jan 2021 16:11:41 +0000 (17:11 +0100)
committerTres Popp <tpopp@google.com>
Tue, 19 Jan 2021 09:20:20 +0000 (10:20 +0100)
Previously uniqueCallSite could have race conditions between different
threads. Now it is accessed with an atomic RMW and will be unique
between different threads.

Differential Revision: https://reviews.llvm.org/D94784

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.h

index 05a0713..8860e90 100644 (file)
@@ -65,7 +65,7 @@
 
 using namespace llvm;
 
-static unsigned int uniqueCallSite = 0;
+static std::atomic<unsigned> GlobalUniqueCallSite;
 
 static cl::opt<bool> sched4reg(
     "nvptx-sched4reg",
@@ -1243,7 +1243,7 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *retTy, const ArgListTy &Args,
     const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
-    const CallBase &CB) const {
+    const CallBase &CB, unsigned UniqueCallSite) const {
   auto PtrVT = getPointerTy(DL);
 
   bool isABI = (STI.getSmVersion() >= 20);
@@ -1252,7 +1252,7 @@ std::string NVPTXTargetLowering::getPrototype(
     return "";
 
   std::stringstream O;
-  O << "prototype_" << uniqueCallSite << " : .callprototype ";
+  O << "prototype_" << UniqueCallSite << " : .callprototype ";
 
   if (retTy->getTypeID() == Type::VoidTyID) {
     O << "()";
@@ -1422,8 +1422,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (!isABI)
     return Chain;
 
+  unsigned UniqueCallSite = GlobalUniqueCallSite.fetch_add(1);
   SDValue tempChain = Chain;
-  Chain = DAG.getCALLSEQ_START(Chain, uniqueCallSite, 0, dl);
+  Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
   SDValue InFlag = Chain.getValue(1);
 
   unsigned paramCount = 0;
@@ -1678,7 +1679,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // The prototype is embedded in a string and put as the operand for a
     // CallPrototype SDNode which will print out to the value of the string.
     SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-    std::string Proto = getPrototype(DL, RetTy, Args, Outs, retAlignment, *CB);
+    std::string Proto =
+        getPrototype(DL, RetTy, Args, Outs, retAlignment, *CB, UniqueCallSite);
     const char *ProtoStr =
       nvTM->getManagedStrPool()->getManagedString(Proto.c_str())->c_str();
     SDValue ProtoOps[] = {
@@ -1734,9 +1736,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
   if (isIndirectCall) {
     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-    SDValue PrototypeOps[] = { Chain,
-                               DAG.getConstant(uniqueCallSite, dl, MVT::i32),
-                               InFlag };
+    SDValue PrototypeOps[] = {
+        Chain, DAG.getConstant(UniqueCallSite, dl, MVT::i32), InFlag};
     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps);
     InFlag = Chain.getValue(1);
   }
@@ -1832,13 +1833,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     }
   }
 
-  Chain = DAG.getCALLSEQ_END(Chain,
-                             DAG.getIntPtrConstant(uniqueCallSite, dl, true),
-                             DAG.getIntPtrConstant(uniqueCallSite + 1, dl,
-                                                   true),
-                             InFlag, dl);
+  Chain = DAG.getCALLSEQ_END(
+      Chain, DAG.getIntPtrConstant(UniqueCallSite, dl, true),
+      DAG.getIntPtrConstant(UniqueCallSite + 1, dl, true), InFlag, dl);
   InFlag = Chain.getValue(1);
-  uniqueCallSite++;
 
   // Append ProxyReg instructions to the chain to make sure that `callseq_end`
   // will not get lost. Otherwise, during libcalls expansion, the nodes can become
index df9cd41..13829b9 100644 (file)
@@ -491,7 +491,8 @@ public:
 
   std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
                            const SmallVectorImpl<ISD::OutputArg> &,
-                           MaybeAlign retAlignment, const CallBase &CB) const;
+                           MaybeAlign retAlignment, const CallBase &CB,
+                           unsigned UniqueCallSite) const;
 
   SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
                       const SmallVectorImpl<ISD::OutputArg> &Outs,