Add support for atomic memory copy lowering
author    Evgeniy Brevnov <ybrevnov@azul.com>
          Fri, 4 Feb 2022 03:54:27 +0000 (10:54 +0700)
committer Evgeniy Brevnov <ybrevnov@azul.com>
          Fri, 8 Apr 2022 03:41:31 +0000 (10:41 +0700)
Currently, the utility supports lowering of non-atomic memory transfer routines only. This patch adds support for the atomic version of memcpy. This may be useful for targets that do not support atomic memcpy natively and therefore have to expand it into an explicit loop.

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D118443
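
The new entry point mirrors the existing expandMemCpyAsLoop API. A minimal sketch of a caller, where the helper name, the iteration scaffolding, and the erase step are illustrative assumptions rather than part of this patch:

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/Analysis/ScalarEvolution.h"
    #include "llvm/Analysis/TargetTransformInfo.h"
    #include "llvm/IR/InstIterator.h"
    #include "llvm/IR/IntrinsicInst.h"
    #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
    using namespace llvm;

    // Hypothetical driver: expand every element-atomic memcpy in F to a loop.
    static void lowerAtomicMemCpys(Function &F, const TargetTransformInfo &TTI,
                                   ScalarEvolution *SE) {
      SmallVector<AtomicMemCpyInst *, 4> Worklist;
      for (Instruction &I : instructions(F))
        if (auto *AMC = dyn_cast<AtomicMemCpyInst>(&I))
          Worklist.push_back(AMC);
      for (AtomicMemCpyInst *AMC : Worklist) {
        // Emits a loop of unordered atomic loads/stores sized by
        // getElementSizeInBytes(); the intrinsic itself is not deleted,
        // so erase it once it has been expanded.
        expandAtomicMemCpyAsLoop(AMC, TTI, SE);
        AMC->eraseFromParent();
      }
    }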

llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h
llvm/lib/Analysis/TargetTransformInfo.cpp
llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
llvm/unittests/Transforms/Utils/MemTransferLowering.cpp

llvm/include/llvm/Analysis/TargetTransformInfo.h
index f9e44fc..acda8a0 100644
@@ -1290,9 +1290,11 @@ public:
                                            Type *ExpectedType) const;
 
   /// \returns The type to use in a loop expansion of a memcpy call.
-  Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
-                                  unsigned SrcAddrSpace, unsigned DestAddrSpace,
-                                  unsigned SrcAlign, unsigned DestAlign) const;
+  Type *
+  getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
+                            unsigned SrcAddrSpace, unsigned DestAddrSpace,
+                            unsigned SrcAlign, unsigned DestAlign,
+                            Optional<uint32_t> AtomicElementSize = None) const;
 
   /// \param[out] OpsOut The operand types to copy RemainingBytes of memory.
   /// \param RemainingBytes The number of bytes to copy.
@@ -1303,7 +1305,8 @@ public:
   void getMemcpyLoopResidualLoweringType(
       SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
       unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-      unsigned SrcAlign, unsigned DestAlign) const;
+      unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicCpySize = None) const;
 
   /// \returns True if the two functions have compatible attributes for inlining
   /// purposes.
@@ -1745,15 +1748,17 @@ public:
   virtual unsigned getAtomicMemIntrinsicMaxElementSize() const = 0;
   virtual Value *getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
                                                    Type *ExpectedType) = 0;
-  virtual Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
-                                          unsigned SrcAddrSpace,
-                                          unsigned DestAddrSpace,
-                                          unsigned SrcAlign,
-                                          unsigned DestAlign) const = 0;
+  virtual Type *
+  getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
+                            unsigned SrcAddrSpace, unsigned DestAddrSpace,
+                            unsigned SrcAlign, unsigned DestAlign,
+                            Optional<uint32_t> AtomicElementSize) const = 0;
+
   virtual void getMemcpyLoopResidualLoweringType(
       SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
       unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-      unsigned SrcAlign, unsigned DestAlign) const = 0;
+      unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicCpySize) const = 0;
   virtual bool areInlineCompatible(const Function *Caller,
                                    const Function *Callee) const = 0;
   virtual bool areTypesABICompatible(const Function *Caller,
@@ -2315,20 +2320,22 @@ public:
                                            Type *ExpectedType) override {
     return Impl.getOrCreateResultFromMemIntrinsic(Inst, ExpectedType);
   }
-  Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
-                                  unsigned SrcAddrSpace, unsigned DestAddrSpace,
-                                  unsigned SrcAlign,
-                                  unsigned DestAlign) const override {
+  Type *getMemcpyLoopLoweringType(
+      LLVMContext &Context, Value *Length, unsigned SrcAddrSpace,
+      unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicElementSize) const override {
     return Impl.getMemcpyLoopLoweringType(Context, Length, SrcAddrSpace,
-                                          DestAddrSpace, SrcAlign, DestAlign);
+                                          DestAddrSpace, SrcAlign, DestAlign,
+                                          AtomicElementSize);
   }
   void getMemcpyLoopResidualLoweringType(
       SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
       unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-      unsigned SrcAlign, unsigned DestAlign) const override {
+      unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicCpySize) const override {
     Impl.getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes,
                                            SrcAddrSpace, DestAddrSpace,
-                                           SrcAlign, DestAlign);
+                                           SrcAlign, DestAlign, AtomicCpySize);
   }
   bool areInlineCompatible(const Function *Caller,
                            const Function *Callee) const override {
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d0d4f4..42d38ac 100644
@@ -703,16 +703,21 @@ public:
 
   Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
                                   unsigned SrcAddrSpace, unsigned DestAddrSpace,
-                                  unsigned SrcAlign, unsigned DestAlign) const {
-    return Type::getInt8Ty(Context);
+                                  unsigned SrcAlign, unsigned DestAlign,
+                                  Optional<uint32_t> AtomicElementSize) const {
+    return AtomicElementSize ? Type::getIntNTy(Context, *AtomicElementSize * 8)
+                             : Type::getInt8Ty(Context);
   }
 
   void getMemcpyLoopResidualLoweringType(
       SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
       unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-      unsigned SrcAlign, unsigned DestAlign) const {
-    for (unsigned i = 0; i != RemainingBytes; ++i)
-      OpsOut.push_back(Type::getInt8Ty(Context));
+      unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicCpySize) const {
+    unsigned OpSizeInBytes = AtomicCpySize ? *AtomicCpySize : 1;
+    Type *OpType = Type::getIntNTy(Context, OpSizeInBytes * 8);
+    for (unsigned i = 0; i != RemainingBytes; i += OpSizeInBytes)
+      OpsOut.push_back(OpType);
   }
 
   bool areInlineCompatible(const Function *Caller,
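
For element-atomic copies the defaults above select the element-sized integer everywhere. A sketch of the resulting types, assuming an LLVMContext Ctx and a TargetTransformInfo TTI in scope (Length is unused by the default implementation):

    // 4-byte elements: the loop operates on i32 instead of i8.
    Type *LoopTy = TTI.getMemcpyLoopLoweringType(
        Ctx, /*Length=*/nullptr, /*SrcAddrSpace=*/0, /*DestAddrSpace=*/0,
        /*SrcAlign=*/4, /*DestAlign=*/4, /*AtomicElementSize=*/4);

    // 8 leftover bytes yield two i32 residual ops. The loop above steps by
    // OpSizeInBytes and terminates because the intrinsic requires the total
    // length (and hence any residual) to be a multiple of the element size.
    SmallVector<Type *, 4> Ops;
    TTI.getMemcpyLoopResidualLoweringType(
        Ops, Ctx, /*RemainingBytes=*/8, /*SrcAddrSpace=*/0,
        /*DestAddrSpace=*/0, /*SrcAlign=*/4, /*DestAlign=*/4,
        /*AtomicCpySize=*/4);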
llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h
index a46b7d4..acf59ff 100644
 #ifndef LLVM_TRANSFORMS_UTILS_LOWERMEMINTRINSICS_H
 #define LLVM_TRANSFORMS_UTILS_LOWERMEMINTRINSICS_H
 
+#include "llvm/ADT/Optional.h"
+
 namespace llvm {
 
+class AtomicMemCpyInst;
 class ConstantInt;
 class Instruction;
 class MemCpyInst;
@@ -32,7 +35,8 @@ void createMemCpyLoopUnknownSize(Instruction *InsertBefore, Value *SrcAddr,
                                  Value *DstAddr, Value *CopyLen, Align SrcAlign,
                                  Align DestAlign, bool SrcIsVolatile,
                                  bool DstIsVolatile, bool CanOverlap,
-                                 const TargetTransformInfo &TTI);
+                                 const TargetTransformInfo &TTI,
+                                 Optional<uint32_t> AtomicElementSize = None);
 
 /// Emit a loop implementing the semantics of an llvm.memcpy whose size is a
 /// compile time constant. Loop is inserted at \p InsertBefore.
@@ -40,7 +44,8 @@ void createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
                                Value *DstAddr, ConstantInt *CopyLen,
                                Align SrcAlign, Align DestAlign,
                                bool SrcIsVolatile, bool DstIsVolatile,
-                               bool CanOverlap, const TargetTransformInfo &TTI);
+                               bool CanOverlap, const TargetTransformInfo &TTI,
+                               Optional<uint32_t> AtomicElementSize = None);
 
 /// Expand \p MemCpy as a loop. \p MemCpy is not deleted.
 void expandMemCpyAsLoop(MemCpyInst *MemCpy, const TargetTransformInfo &TTI,
@@ -52,6 +57,11 @@ void expandMemMoveAsLoop(MemMoveInst *MemMove);
 /// Expand \p MemSet as a loop. \p MemSet is not deleted.
 void expandMemSetAsLoop(MemSetInst *MemSet);
 
+/// Expand \p AtomicMemCpy as a loop. \p AtomicMemCpy is not deleted.
+void expandAtomicMemCpyAsLoop(AtomicMemCpyInst *AtomicMemCpy,
+                              const TargetTransformInfo &TTI,
+                              ScalarEvolution *SE);
+
 } // End llvm namespace
 
 #endif
llvm/lib/Analysis/TargetTransformInfo.cpp
index 7366c70..eded89b 100644
@@ -976,18 +976,21 @@ Value *TargetTransformInfo::getOrCreateResultFromMemIntrinsic(
 
 Type *TargetTransformInfo::getMemcpyLoopLoweringType(
     LLVMContext &Context, Value *Length, unsigned SrcAddrSpace,
-    unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const {
+    unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign,
+    Optional<uint32_t> AtomicElementSize) const {
   return TTIImpl->getMemcpyLoopLoweringType(Context, Length, SrcAddrSpace,
-                                            DestAddrSpace, SrcAlign, DestAlign);
+                                            DestAddrSpace, SrcAlign, DestAlign,
+                                            AtomicElementSize);
 }
 
 void TargetTransformInfo::getMemcpyLoopResidualLoweringType(
     SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
     unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-    unsigned SrcAlign, unsigned DestAlign) const {
-  TTIImpl->getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes,
-                                             SrcAddrSpace, DestAddrSpace,
-                                             SrcAlign, DestAlign);
+    unsigned SrcAlign, unsigned DestAlign,
+    Optional<uint32_t> AtomicCpySize) const {
+  TTIImpl->getMemcpyLoopResidualLoweringType(
+      OpsOut, Context, RemainingBytes, SrcAddrSpace, DestAddrSpace, SrcAlign,
+      DestAlign, AtomicCpySize);
 }
 
 bool TargetTransformInfo::areInlineCompatible(const Function *Caller,
llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
index bdd22a4..0afebe0 100644
@@ -410,11 +410,14 @@ bool GCNTTIImpl::isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes,
 // unaligned access is legal?
 //
 // FIXME: This could use fine tuning and microbenchmarks.
-Type *GCNTTIImpl::getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
-                                            unsigned SrcAddrSpace,
-                                            unsigned DestAddrSpace,
-                                            unsigned SrcAlign,
-                                            unsigned DestAlign) const {
+Type *GCNTTIImpl::getMemcpyLoopLoweringType(
+    LLVMContext &Context, Value *Length, unsigned SrcAddrSpace,
+    unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign,
+    Optional<uint32_t> AtomicElementSize) const {
+
+  if (AtomicElementSize)
+    return Type::getIntNTy(Context, *AtomicElementSize * 8);
+
   unsigned MinAlign = std::min(SrcAlign, DestAlign);
 
   // A (multi-)dword access at an address == 2 (mod 4) will be decomposed by the
@@ -439,11 +442,17 @@ Type *GCNTTIImpl::getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
 }
 
 void GCNTTIImpl::getMemcpyLoopResidualLoweringType(
-  SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
-  unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-  unsigned SrcAlign, unsigned DestAlign) const {
+    SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
+    unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
+    unsigned SrcAlign, unsigned DestAlign,
+    Optional<uint32_t> AtomicCpySize) const {
   assert(RemainingBytes < 16);
 
+  if (AtomicCpySize)
+    return BaseT::getMemcpyLoopResidualLoweringType(
+        OpsOut, Context, RemainingBytes, SrcAddrSpace, DestAddrSpace, SrcAlign,
+        DestAlign, AtomicCpySize);
+
   unsigned MinAlign = std::min(SrcAlign, DestAlign);
 
   if (MinAlign != 2) {
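
With the early return above, atomic requests never reach the wide-type tail logic that follows and always receive element-sized integers from the base implementation. A contrast sketch for 12 residual bytes, with the context, TTI, and alignment values assumed:

    SmallVector<Type *, 4> AtomicOps, PlainOps;
    // Atomic, 4-byte elements: three i32 ops, each of which can be made an
    // unordered atomic access.
    TTI.getMemcpyLoopResidualLoweringType(
        AtomicOps, Ctx, /*RemainingBytes=*/12, /*SrcAddrSpace=*/0,
        /*DestAddrSpace=*/0, /*SrcAlign=*/4, /*DestAlign=*/4,
        /*AtomicCpySize=*/4);
    // Non-atomic: free to pick wider ops for the same tail (e.g. i64 + i32,
    // per the tail logic elided above).
    TTI.getMemcpyLoopResidualLoweringType(
        PlainOps, Ctx, /*RemainingBytes=*/12, /*SrcAddrSpace=*/0,
        /*DestAddrSpace=*/0, /*SrcAlign=*/4, /*DestAlign=*/4,
        /*AtomicCpySize=*/None);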
llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
index 4743042..ebeb05e 100644
@@ -135,15 +135,14 @@ public:
                                     unsigned AddrSpace) const;
   Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
                                   unsigned SrcAddrSpace, unsigned DestAddrSpace,
-                                  unsigned SrcAlign, unsigned DestAlign) const;
-
-  void getMemcpyLoopResidualLoweringType(SmallVectorImpl<Type *> &OpsOut,
-                                         LLVMContext &Context,
-                                         unsigned RemainingBytes,
-                                         unsigned SrcAddrSpace,
-                                         unsigned DestAddrSpace,
-                                         unsigned SrcAlign,
-                                         unsigned DestAlign) const;
+                                  unsigned SrcAlign, unsigned DestAlign,
+                                  Optional<uint32_t> AtomicElementSize) const;
+
+  void getMemcpyLoopResidualLoweringType(
+      SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
+      unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
+      unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicCpySize) const;
   unsigned getMaxInterleaveFactor(unsigned VF);
 
   bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info) const;
llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
index 3848e34..b4acb1b 100644
@@ -21,7 +21,8 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
                                      Align SrcAlign, Align DstAlign,
                                      bool SrcIsVolatile, bool DstIsVolatile,
                                      bool CanOverlap,
-                                     const TargetTransformInfo &TTI) {
+                                     const TargetTransformInfo &TTI,
+                                     Optional<uint32_t> AtomicElementSize) {
   // No need to expand zero length copies.
   if (CopyLen->isZero())
     return;
@@ -41,9 +42,15 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
 
   Type *TypeOfCopyLen = CopyLen->getType();
   Type *LoopOpType = TTI.getMemcpyLoopLoweringType(
-      Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value());
+      Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value(),
+      AtomicElementSize);
+  assert((!AtomicElementSize || !LoopOpType->isVectorTy()) &&
+         "Atomic memcpy lowering is not supported for vector operand type");
 
   unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType);
+  assert((!AtomicElementSize || LoopOpSize % *AtomicElementSize == 0) &&
+         "Atomic memcpy lowering is not supported for selected operand size");
+
   uint64_t LoopEndCount = CopyLen->getZExtValue() / LoopOpSize;
 
   if (LoopEndCount != 0) {
@@ -90,6 +97,10 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
       // Indicate that stores don't overlap loads.
       Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope));
     }
+    if (AtomicElementSize) {
+      Load->setAtomic(AtomicOrdering::Unordered);
+      Store->setAtomic(AtomicOrdering::Unordered);
+    }
     Value *NewIndex =
         LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1U));
     LoopIndex->addIncoming(NewIndex, LoopBB);
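
Unordered is the weakest atomic ordering: each element-sized access is indivisible, but no ordering is imposed between elements, matching the llvm.memcpy.element.unordered.atomic contract. The pattern in isolation, with the builder and address operands assumed:

    // B is an IRBuilder<>; SrcGEP/DstGEP address one naturally aligned element.
    LoadInst *Load = B.CreateAlignedLoad(B.getInt32Ty(), SrcGEP, Align(4));
    Load->setAtomic(AtomicOrdering::Unordered);
    StoreInst *Store = B.CreateAlignedStore(Load, DstGEP, Align(4));
    Store->setAtomic(AtomicOrdering::Unordered);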
@@ -109,7 +120,7 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
     SmallVector<Type *, 5> RemainingOps;
     TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes,
                                           SrcAS, DstAS, SrcAlign.value(),
-                                          DstAlign.value());
+                                          DstAlign.value(), AtomicElementSize);
 
     for (auto OpTy : RemainingOps) {
       Align PartSrcAlign(commonAlignment(SrcAlign, BytesCopied));
@@ -117,6 +128,10 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
 
      // Calculate the new index
       unsigned OperandSize = DL.getTypeStoreSize(OpTy);
+      assert(
+          (!AtomicElementSize || OperandSize % *AtomicElementSize == 0) &&
+          "Atomic memcpy lowering is not supported for selected operand size");
+
       uint64_t GepIndex = BytesCopied / OperandSize;
       assert(GepIndex * OperandSize == BytesCopied &&
              "Division should have no Remainder!");
@@ -147,6 +162,10 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
         // Indicate that stores don't overlap loads.
         Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope));
       }
+      if (AtomicElementSize) {
+        Load->setAtomic(AtomicOrdering::Unordered);
+        Store->setAtomic(AtomicOrdering::Unordered);
+      }
       BytesCopied += OperandSize;
     }
   }
@@ -159,7 +178,8 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore,
                                        Value *CopyLen, Align SrcAlign,
                                        Align DstAlign, bool SrcIsVolatile,
                                        bool DstIsVolatile, bool CanOverlap,
-                                       const TargetTransformInfo &TTI) {
+                                       const TargetTransformInfo &TTI,
+                                       Optional<uint32_t> AtomicElementSize) {
   BasicBlock *PreLoopBB = InsertBefore->getParent();
   BasicBlock *PostLoopBB =
       PreLoopBB->splitBasicBlock(InsertBefore, "post-loop-memcpy-expansion");
@@ -176,8 +196,13 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore,
   unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
 
   Type *LoopOpType = TTI.getMemcpyLoopLoweringType(
-      Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value());
+      Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value(),
+      AtomicElementSize);
+  assert((!AtomicElementSize || !LoopOpType->isVectorTy()) &&
+         "Atomic memcpy lowering is not supported for vector operand type");
   unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType);
+  assert((!AtomicElementSize || LoopOpSize % *AtomicElementSize == 0) &&
+         "Atomic memcpy lowering is not supported for selected operand size");
 
   IRBuilder<> PLBuilder(PreLoopBB->getTerminator());
 
@@ -225,14 +250,27 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore,
     // Indicate that stores don't overlap loads.
     Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope));
   }
+  if (AtomicElementSize) {
+    Load->setAtomic(AtomicOrdering::Unordered);
+    Store->setAtomic(AtomicOrdering::Unordered);
+  }
   Value *NewIndex =
       LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLenType, 1U));
   LoopIndex->addIncoming(NewIndex, LoopBB);
 
-  if (!LoopOpIsInt8) {
-   // Add in the
-   Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize);
-   Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual);
+  bool RequiresResidual =
+      !LoopOpIsInt8 &&
+      !(AtomicElementSize && LoopOpSize == *AtomicElementSize);
+  if (RequiresResidual) {
+    Type *ResLoopOpType = AtomicElementSize
+                              ? Type::getIntNTy(Ctx, *AtomicElementSize * 8)
+                              : Int8Type;
+    unsigned ResLoopOpSize = DL.getTypeStoreSize(ResLoopOpType);
+    assert(ResLoopOpSize == (AtomicElementSize ? *AtomicElementSize : 1) &&
+           "Store size is expected to match type size");
+
+    // Compute how many bytes the main loop leaves over; the residual loop
+    // below copies that tail one ResLoopOpType element at a time.
+    Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize);
+    Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual);
 
     // Loop body for the residual copy.
     BasicBlock *ResLoopBB = BasicBlock::Create(Ctx, "loop-memcpy-residual",
@@ -267,30 +305,34 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore,
         ResBuilder.CreatePHI(CopyLenType, 2, "residual-loop-index");
     ResidualIndex->addIncoming(Zero, ResHeaderBB);
 
-    Value *SrcAsInt8 =
-        ResBuilder.CreateBitCast(SrcAddr, PointerType::get(Int8Type, SrcAS));
-    Value *DstAsInt8 =
-        ResBuilder.CreateBitCast(DstAddr, PointerType::get(Int8Type, DstAS));
+    Value *SrcAsResLoopOpType = ResBuilder.CreateBitCast(
+        SrcAddr, PointerType::get(ResLoopOpType, SrcAS));
+    Value *DstAsResLoopOpType = ResBuilder.CreateBitCast(
+        DstAddr, PointerType::get(ResLoopOpType, DstAS));
     Value *FullOffset = ResBuilder.CreateAdd(RuntimeBytesCopied, ResidualIndex);
-    Value *SrcGEP =
-        ResBuilder.CreateInBoundsGEP(Int8Type, SrcAsInt8, FullOffset);
-    LoadInst *Load = ResBuilder.CreateAlignedLoad(Int8Type, SrcGEP,
+    // FullOffset counts bytes, so convert it to ResLoopOpType units before
+    // feeding it to the typed GEPs below. Both addends are multiples of
+    // ResLoopOpSize, making the division exact.
+    Value *ElemOffset = ResBuilder.CreateExactUDiv(
+        FullOffset, ConstantInt::get(CopyLenType, ResLoopOpSize));
+    Value *SrcGEP = ResBuilder.CreateInBoundsGEP(
+        ResLoopOpType, SrcAsResLoopOpType, ElemOffset);
+    LoadInst *Load = ResBuilder.CreateAlignedLoad(ResLoopOpType, SrcGEP,
                                                   PartSrcAlign, SrcIsVolatile);
     if (!CanOverlap) {
       // Set alias scope for loads.
       Load->setMetadata(LLVMContext::MD_alias_scope,
                         MDNode::get(Ctx, NewScope));
     }
-    Value *DstGEP =
-        ResBuilder.CreateInBoundsGEP(Int8Type, DstAsInt8, FullOffset);
+    Value *DstGEP = ResBuilder.CreateInBoundsGEP(
+        ResLoopOpType, DstAsResLoopOpType, ElemOffset);
     StoreInst *Store = ResBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign,
                                                      DstIsVolatile);
     if (!CanOverlap) {
       // Indicate that stores don't overlap loads.
       Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope));
     }
-    Value *ResNewIndex =
-        ResBuilder.CreateAdd(ResidualIndex, ConstantInt::get(CopyLenType, 1U));
+    if (AtomicElementSize) {
+      Load->setAtomic(AtomicOrdering::Unordered);
+      Store->setAtomic(AtomicOrdering::Unordered);
+    }
+    Value *ResNewIndex = ResBuilder.CreateAdd(
+        ResidualIndex, ConstantInt::get(CopyLenType, ResLoopOpSize));
     ResidualIndex->addIncoming(ResNewIndex, ResLoopBB);
 
     // Create the loop branch condition.
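
The residual index above advances in bytes, one element per iteration. A hypothetical helper mirroring that accounting, assuming LoopOpSize = 8 (an i64 main loop) and a 4-byte element size:

    // N is the total copy length in bytes; the intrinsic contract guarantees
    // N % 4 == 0, so the residual is either empty or whole i32 elements.
    static uint64_t residualAtomicOps(uint64_t N) {
      uint64_t RuntimeResidual = N % 8; // bytes left after the main i64 loop
      return RuntimeResidual / 4;       // one unordered i32 copy per element
    }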
@@ -471,17 +513,21 @@ static void createMemSetLoop(Instruction *InsertBefore, Value *DstAddr,
                            NewBB);
 }
 
-void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy,
-                              const TargetTransformInfo &TTI,
-                              ScalarEvolution *SE) {
-  bool CanOverlap = true;
+template <typename T>
+static bool canOverlap(MemTransferBase<T> *Memcpy, ScalarEvolution *SE) {
   if (SE) {
     auto *SrcSCEV = SE->getSCEV(Memcpy->getRawSource());
     auto *DestSCEV = SE->getSCEV(Memcpy->getRawDest());
     if (SE->isKnownPredicateAt(CmpInst::ICMP_NE, SrcSCEV, DestSCEV, Memcpy))
-      CanOverlap = false;
+      return false;
   }
+  return true;
+}
 
+void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy,
+                              const TargetTransformInfo &TTI,
+                              ScalarEvolution *SE) {
+  bool CanOverlap = canOverlap(Memcpy, SE);
   if (ConstantInt *CI = dyn_cast<ConstantInt>(Memcpy->getLength())) {
     createMemCpyLoopKnownSize(
         /* InsertBefore */ Memcpy,
@@ -528,3 +574,35 @@ void llvm::expandMemSetAsLoop(MemSetInst *Memset) {
                    /* Alignment */ Memset->getDestAlign().valueOrOne(),
                    Memset->isVolatile());
 }
+
+void llvm::expandAtomicMemCpyAsLoop(AtomicMemCpyInst *AtomicMemcpy,
+                                    const TargetTransformInfo &TTI,
+                                    ScalarEvolution *SE) {
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(AtomicMemcpy->getLength())) {
+    createMemCpyLoopKnownSize(
+        /* InsertBefore */ AtomicMemcpy,
+        /* SrcAddr */ AtomicMemcpy->getRawSource(),
+        /* DstAddr */ AtomicMemcpy->getRawDest(),
+        /* CopyLen */ CI,
+        /* SrcAlign */ AtomicMemcpy->getSourceAlign().valueOrOne(),
+        /* DestAlign */ AtomicMemcpy->getDestAlign().valueOrOne(),
+        /* SrcIsVolatile */ AtomicMemcpy->isVolatile(),
+        /* DstIsVolatile */ AtomicMemcpy->isVolatile(),
+        /* CanOverlap */ false, // SrcAddr & DstAddr may not overlap by spec.
+        /* TargetTransformInfo */ TTI,
+        /* AtomicCpySize */ AtomicMemcpy->getElementSizeInBytes());
+  } else {
+    createMemCpyLoopUnknownSize(
+        /* InsertBefore */ AtomicMemcpy,
+        /* SrcAddr */ AtomicMemcpy->getRawSource(),
+        /* DstAddr */ AtomicMemcpy->getRawDest(),
+        /* CopyLen */ AtomicMemcpy->getLength(),
+        /* SrcAlign */ AtomicMemcpy->getSourceAlign().valueOrOne(),
+        /* DestAlign */ AtomicMemcpy->getDestAlign().valueOrOne(),
+        /* SrcIsVolatile */ AtomicMemcpy->isVolatile(),
+        /* DstIsVolatile */ AtomicMemcpy->isVolatile(),
+        /* CanOverlap */ false, // SrcAddr & DstAddr may not overlap by spec.
+        /* TargetTransformInfo */ TTI,
+        /* AtomicCpySize */ AtomicMemcpy->getElementSizeInBytes());
+  }
+}
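
For the 1024-byte, 4-byte-element copy in the test below, the known-size path emits a 256-iteration loop of unordered i32 copies and no residual ops. A mirror of the trip-count arithmetic (a sketch, not test code):

    uint64_t CopyLen = 1024, LoopOpSize = 4;      // i32 loop type, 4B elements
    uint64_t LoopEndCount = CopyLen / LoopOpSize; // 256 iterations
    uint64_t BytesCopied = LoopEndCount * LoopOpSize;
    assert(CopyLen - BytesCopied == 0 && "no residual ops expected");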
llvm/unittests/Transforms/Utils/MemTransferLowering.cpp
index df86e16..62afd91 100644
@@ -174,4 +174,94 @@ TEST_F(MemTransferLowerTest, VecMemCpyKnownLength) {
 
   MPM.run(*M, MAM);
 }
+
+TEST_F(MemTransferLowerTest, AtomicMemCpyKnownLength) {
+  ParseAssembly("declare void "
+                "@llvm.memcpy.element.unordered.atomic.p0i32.p0i32.i64(i32*, "
+                "i32 *, i64, i32)\n"
+                "define void @foo(i32* %dst, i32* %src, i64 %n) optsize {\n"
+                "entry:\n"
+                "  %is_not_equal = icmp ne i32* %dst, %src\n"
+                "  br i1 %is_not_equal, label %memcpy, label %exit\n"
+                "memcpy:\n"
+                "  call void "
+                "@llvm.memcpy.element.unordered.atomic.p0i32.p0i32.i64(i32* "
+                "%dst, i32* %src, "
+                "i64 1024, i32 4)\n"
+                "  br label %exit\n"
+                "exit:\n"
+                "  ret void\n"
+                "}\n");
+
+  FunctionPassManager FPM;
+  FPM.addPass(ForwardingPass(
+      [=](Function &F, FunctionAnalysisManager &FAM) -> PreservedAnalyses {
+        TargetTransformInfo TTI(M->getDataLayout());
+        auto *MemCpyBB = getBasicBlockByName(F, "memcpy");
+        Instruction *Inst = &MemCpyBB->front();
+        assert(isa<AtomicMemCpyInst>(Inst) &&
+               "Expecting llvm.memcpy.element.unordered.atomic instruction");
+        AtomicMemCpyInst *MemCpyI = cast<AtomicMemCpyInst>(Inst);
+        auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
+        expandAtomicMemCpyAsLoop(MemCpyI, TTI, &SE);
+        auto *CopyLoopBB = getBasicBlockByName(F, "load-store-loop");
+        Instruction *LoadInst =
+            getInstructionByOpcode(*CopyLoopBB, Instruction::Load, 1);
+        EXPECT_TRUE(LoadInst->isAtomic());
+        EXPECT_NE(LoadInst->getMetadata(LLVMContext::MD_alias_scope), nullptr);
+        Instruction *StoreInst =
+            getInstructionByOpcode(*CopyLoopBB, Instruction::Store, 1);
+        EXPECT_TRUE(StoreInst->isAtomic());
+        EXPECT_NE(StoreInst->getMetadata(LLVMContext::MD_noalias), nullptr);
+        return PreservedAnalyses::none();
+      }));
+  MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
+
+  MPM.run(*M, MAM);
+}
+
+TEST_F(MemTransferLowerTest, AtomicMemCpyUnknownLength) {
+  ParseAssembly("declare void "
+                "@llvm.memcpy.element.unordered.atomic.p0i32.p0i32.i64(i32*, "
+                "i32 *, i64, i32)\n"
+                "define void @foo(i32* %dst, i32* %src, i64 %n) optsize {\n"
+                "entry:\n"
+                "  %is_not_equal = icmp ne i32* %dst, %src\n"
+                "  br i1 %is_not_equal, label %memcpy, label %exit\n"
+                "memcpy:\n"
+                "  call void "
+                "@llvm.memcpy.element.unordered.atomic.p0i32.p0i32.i64(i32* "
+                "%dst, i32* %src, "
+                "i64 %n, i32 4)\n"
+                "  br label %exit\n"
+                "exit:\n"
+                "  ret void\n"
+                "}\n");
+
+  FunctionPassManager FPM;
+  FPM.addPass(ForwardingPass(
+      [=](Function &F, FunctionAnalysisManager &FAM) -> PreservedAnalyses {
+        TargetTransformInfo TTI(M->getDataLayout());
+        auto *MemCpyBB = getBasicBlockByName(F, "memcpy");
+        Instruction *Inst = &MemCpyBB->front();
+        assert(isa<AtomicMemCpyInst>(Inst) &&
+               "Expecting llvm.memcpy.element.unordered.atomic instruction");
+        AtomicMemCpyInst *MemCpyI = cast<AtomicMemCpyInst>(Inst);
+        auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
+        expandAtomicMemCpyAsLoop(MemCpyI, TTI, &SE);
+        auto *CopyLoopBB = getBasicBlockByName(F, "loop-memcpy-expansion");
+        Instruction *LoadInst =
+            getInstructionByOpcode(*CopyLoopBB, Instruction::Load, 1);
+        EXPECT_TRUE(LoadInst->isAtomic());
+        EXPECT_NE(LoadInst->getMetadata(LLVMContext::MD_alias_scope), nullptr);
+        Instruction *StoreInst =
+            getInstructionByOpcode(*CopyLoopBB, Instruction::Store, 1);
+        EXPECT_TRUE(StoreInst->isAtomic());
+        EXPECT_NE(StoreInst->getMetadata(LLVMContext::MD_noalias), nullptr);
+        return PreservedAnalyses::none();
+      }));
+  MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
+
+  MPM.run(*M, MAM);
+}
 } // namespace