From 0c7cd071f7916a4f9a0bdc70a58b8477c3700e38 Mon Sep 17 00:00:00 2001 From: Clement Courbet Date: Wed, 25 Oct 2017 11:02:09 +0000 Subject: [PATCH] Re-land "[CodeGen][ExpandMemcmp][NFC] Allow memcmp to expand to vector loads (1)" Compute the actual decomposition only after deciding whether to expand or not. Otherwise, it's easy to make the compiler OOM with: `memcmp(dst, src, 0xffffffffffffffff);`, which typically happens if someone mistakenly passes a negative value. Add a test. This reverts commit f8fc02fbd4ab33383c010d33675acf9763d0bd44. llvm-svn: 316567 --- llvm/lib/CodeGen/CodeGenPrepare.cpp | 422 +++++++++++++++++++----------------- llvm/test/CodeGen/X86/memcmp.ll | 7 + 2 files changed, 234 insertions(+), 195 deletions(-) diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp index 9f0c1f7..1e5f153 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -1710,43 +1710,69 @@ class MemCmpExpansion { ResultBlock() = default; }; - CallInst *CI; + CallInst *const CI; ResultBlock ResBlock; + const uint64_t Size; unsigned MaxLoadSize; - unsigned NumBlocks; - unsigned NumBlocksNonOneByte; - unsigned NumLoadsPerBlock; + uint64_t NumLoads; + uint64_t NumLoadsNonOneByte; + const uint64_t NumLoadsPerBlock; std::vector<BasicBlock *> LoadCmpBlocks; BasicBlock *EndBlock; PHINode *PhiRes; - bool IsUsedForZeroCmp; + const bool IsUsedForZeroCmp; const DataLayout &DL; IRBuilder<> Builder; + // Represents the decomposition in blocks of the expansion. For example, + // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and + // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}]. + // TODO(courbet): Involve the target more in this computation. On X86, 7 + // bytes can be done more efficiently with two overlapping 4-byte loads than + // covering the interval with [{4, 0},{2, 4},{1, 6}]. + struct LoadEntry { + LoadEntry(unsigned LoadSize, uint64_t Offset) + : LoadSize(LoadSize), Offset(Offset) { + assert(Offset % LoadSize == 0 && "invalid load entry"); + } + + uint64_t getGEPIndex() const { return Offset / LoadSize; } + + // The size of the load for this block, in bytes. + const unsigned LoadSize; + // The offset of this load WRT the base pointer, in bytes. + const uint64_t Offset; + }; + SmallVector<LoadEntry, 8> LoadSequence; + void computeLoadSequence(); - unsigned calculateNumBlocks(unsigned Size); void createLoadCmpBlocks(); void createResultBlock(); void setupResultBlockPHINodes(); void setupEndBlockPHINodes(); - void emitLoadCompareBlock(unsigned Index, unsigned LoadSize, - unsigned GEPIndex); - Value *getCompareLoadPairs(unsigned Index, unsigned Size, - unsigned &NumBytesProcessed); - void emitLoadCompareBlockMultipleLoads(unsigned Index, unsigned Size, - unsigned &NumBytesProcessed); - void emitLoadCompareByteBlock(unsigned Index, unsigned GEPIndex); + Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex); + void emitLoadCompareBlock(unsigned BlockIndex); + void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, + unsigned &LoadIndex); + void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex); void emitMemCmpResultBlock(); - Value *getMemCmpExpansionZeroCase(unsigned Size); - Value *getMemCmpEqZeroOneBlock(unsigned Size); - Value *getMemCmpOneBlock(unsigned Size); - unsigned getLoadSize(unsigned Size); - unsigned getNumLoads(unsigned Size); + Value *getMemCmpExpansionZeroCase(); + Value *getMemCmpEqZeroOneBlock(); + Value *getMemCmpOneBlock(); -public: + // Computes the decomposition. 
This is the common code to compute the number + // of loads and the actual load sequence. `callback` is called with each load + // size and number of loads for the block size. + template <typename CallBackT> + void getDecomposition(CallBackT callback) const; + + public: MemCmpExpansion(CallInst *CI, uint64_t Size, unsigned MaxLoadSize, unsigned NumLoadsPerBlock, const DataLayout &DL); - Value *getMemCmpExpansion(uint64_t Size); + unsigned getNumBlocks(); + uint64_t getNumLoads() const { return NumLoads; } + + Value *getMemCmpExpansion(); }; } // end anonymous namespace @@ -1759,43 +1785,74 @@ public: // return from. // 3. ResultBlock, block to branch to for early exit when a // LoadCmpBlock finds a difference. -MemCmpExpansion::MemCmpExpansion(CallInst *CI, uint64_t Size, - unsigned MaxLoadSize, unsigned LoadsPerBlock, +MemCmpExpansion::MemCmpExpansion(CallInst *const CI, uint64_t Size, + const unsigned MaxLoadSize, + const unsigned LoadsPerBlock, const DataLayout &TheDataLayout) - : CI(CI), MaxLoadSize(MaxLoadSize), NumLoadsPerBlock(LoadsPerBlock), - DL(TheDataLayout), Builder(CI) { - // A memcmp with zero-comparison with only one block of load and compare does - // not need to set up any extra blocks. This case could be handled in the DAG, - // but since we have all of the machinery to flexibly expand any memcpy here, - // we choose to handle this case too to avoid fragmented lowering. - IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI); - NumBlocks = calculateNumBlocks(Size); - if ((!IsUsedForZeroCmp && NumLoadsPerBlock != 1) || NumBlocks != 1) { - BasicBlock *StartBlock = CI->getParent(); - EndBlock = StartBlock->splitBasicBlock(CI, "endblock"); - setupEndBlockPHINodes(); - createResultBlock(); - - // If return value of memcmp is not used in a zero equality, we need to - // calculate which source was larger. The calculation requires the - // two loaded source values of each load compare block. - // These will be saved in the phi nodes created by setupResultBlockPHINodes. - if (!IsUsedForZeroCmp) - setupResultBlockPHINodes(); - - // Create the number of required load compare basic blocks. - createLoadCmpBlocks(); + : CI(CI), + Size(Size), + MaxLoadSize(MaxLoadSize), + NumLoads(0), + NumLoadsNonOneByte(0), + NumLoadsPerBlock(LoadsPerBlock), + IsUsedForZeroCmp(isOnlyUsedInZeroEqualityComparison(CI)), + DL(TheDataLayout), + Builder(CI) { + // Scale the max size down if the target can load more bytes than we need. + while (this->MaxLoadSize > Size) { + this->MaxLoadSize /= 2; + } + // Compute the number of loads. At that point we don't want to compute the + // actual decomposition because it might be too large to fit in memory. + getDecomposition([this](unsigned LoadSize, uint64_t NumLoadsForSize) { + NumLoads += NumLoadsForSize; + }); +} - // Update the terminator added by splitBasicBlock to branch to the first - // LoadCmpBlock. - StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]); +template <typename CallBackT> +void MemCmpExpansion::getDecomposition(CallBackT callback) const { + unsigned LoadSize = this->MaxLoadSize; + assert(Size > 0 && "zero blocks"); + uint64_t CurSize = Size; + while (CurSize) { + assert(LoadSize > 0 && "zero load size"); + const uint64_t NumLoadsForThisSize = CurSize / LoadSize; + if (NumLoadsForThisSize > 0) { + callback(LoadSize, NumLoadsForThisSize); + CurSize = CurSize % LoadSize; + } + // FIXME: This can result in a non-native load size (e.g. X86-32+SSE can + // load 16 and 4 but not 8), which throws the load count off (e.g. 
in the + // aforementioned case, 16 bytes will count for 2 loads but will generate + // 4). LoadSize /= 2; } } - Builder.SetCurrentDebugLocation(CI->getDebugLoc()); +void MemCmpExpansion::computeLoadSequence() { + uint64_t Offset = 0; + getDecomposition( + [this, &Offset](unsigned LoadSize, uint64_t NumLoadsForSize) { + for (uint64_t I = 0; I < NumLoadsForSize; ++I) { + LoadSequence.push_back({LoadSize, Offset}); + Offset += LoadSize; + } + if (LoadSize > 1) { + ++NumLoadsNonOneByte; + } + }); + assert(LoadSequence.size() == getNumLoads() && "mismatch in number of loads"); +} + +unsigned MemCmpExpansion::getNumBlocks() { + if (IsUsedForZeroCmp) + return getNumLoads() / NumLoadsPerBlock + + (getNumLoads() % NumLoadsPerBlock != 0 ? 1 : 0); + return getNumLoads(); } void MemCmpExpansion::createLoadCmpBlocks() { - for (unsigned i = 0; i < NumBlocks; i++) { + for (unsigned i = 0; i < getNumBlocks(); i++) { BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb", EndBlock->getParent(), EndBlock); LoadCmpBlocks.push_back(BB); @@ -1811,12 +1868,12 @@ void MemCmpExpansion::createResultBlock() { // It loads 1 byte from each source of the memcmp parameters with the given // GEPIndex. It then subtracts the two loaded values and adds this result to the // final phi node for selecting the memcmp result. -void MemCmpExpansion::emitLoadCompareByteBlock(unsigned Index, +void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex) { Value *Source1 = CI->getArgOperand(0); Value *Source2 = CI->getArgOperand(1); - Builder.SetInsertPoint(LoadCmpBlocks[Index]); + Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); Type *LoadSizeType = Type::getInt8Ty(CI->getContext()); // Cast source to LoadSizeType*. if (Source1->getType() != LoadSizeType) @@ -1839,15 +1896,15 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned Index, LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext())); Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2); - PhiRes->addIncoming(Diff, LoadCmpBlocks[Index]); + PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]); - if (Index < (LoadCmpBlocks.size() - 1)) { + if (BlockIndex < (LoadCmpBlocks.size() - 1)) { // Early exit branch if difference found to EndBlock. Otherwise, continue to // next LoadCmpBlock, Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff, ConstantInt::get(Diff->getType(), 0)); BranchInst *CmpBr = - BranchInst::Create(EndBlock, LoadCmpBlocks[Index + 1], Cmp); + BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp); Builder.Insert(CmpBr); } else { // The last block has an unconditional branch to EndBlock. @@ -1856,42 +1913,37 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned Index, } } -unsigned MemCmpExpansion::getNumLoads(unsigned Size) { - return (Size / MaxLoadSize) + countPopulation(Size % MaxLoadSize); - } - -unsigned MemCmpExpansion::getLoadSize(unsigned Size) { - return MinAlign(PowerOf2Floor(Size), MaxLoadSize); -} - /// Generate an equality comparison for one or more pairs of loaded values. /// This is used in the case where the memcmp() call is compared equal or not /// equal to zero. 
-Value *MemCmpExpansion::getCompareLoadPairs(unsigned Index, unsigned Size, - unsigned &NumBytesProcessed) { +Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex, + unsigned &LoadIndex) { + assert(LoadIndex < getNumLoads() && + "getCompareLoadPairs() called with no remaining loads"); std::vector<Value *> XorList, OrList; Value *Diff; - unsigned RemainingBytes = Size - NumBytesProcessed; - unsigned NumLoadsRemaining = getNumLoads(RemainingBytes); - unsigned NumLoads = std::min(NumLoadsRemaining, NumLoadsPerBlock); + const unsigned NumLoads = + std::min(getNumLoads() - LoadIndex, NumLoadsPerBlock); // For a single-block expansion, start inserting before the memcmp call. if (LoadCmpBlocks.empty()) Builder.SetInsertPoint(CI); else - Builder.SetInsertPoint(LoadCmpBlocks[Index]); + Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); Value *Cmp = nullptr; - for (unsigned i = 0; i < NumLoads; ++i) { - unsigned LoadSize = getLoadSize(RemainingBytes); - unsigned GEPIndex = NumBytesProcessed / LoadSize; - NumBytesProcessed += LoadSize; - RemainingBytes -= LoadSize; - - Type *LoadSizeType = IntegerType::get(CI->getContext(), LoadSize * 8); - Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); - assert(LoadSize <= MaxLoadSize && "Unexpected load type"); + // If we have multiple loads per block, we need to generate a composite + // comparison using xor+or. The type for the combinations is the largest load + // type. + IntegerType *const MaxLoadType = + NumLoads == 1 ? nullptr + : IntegerType::get(CI->getContext(), MaxLoadSize * 8); + for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) { + const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex]; + + IntegerType *LoadSizeType = + IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8); Value *Source1 = CI->getArgOperand(0); Value *Source2 = CI->getArgOperand(1); @@ -1902,12 +1954,14 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned Index, unsigned Size, if (Source2->getType() != LoadSizeType) Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); - // Get the base address using the GEPIndex. - if (GEPIndex != 0) { - Source1 = Builder.CreateGEP(LoadSizeType, Source1, - ConstantInt::get(LoadSizeType, GEPIndex)); - Source2 = Builder.CreateGEP(LoadSizeType, Source2, - ConstantInt::get(LoadSizeType, GEPIndex)); + // Get the base address using a GEP. + if (CurLoadEntry.Offset != 0) { + Source1 = Builder.CreateGEP( + LoadSizeType, Source1, + ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); + Source2 = Builder.CreateGEP( + LoadSizeType, Source2, + ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); } // Get a constant or load a value for each source address. @@ -1964,13 +2018,13 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned Index, unsigned Size, return Cmp; } -void MemCmpExpansion::emitLoadCompareBlockMultipleLoads( - unsigned Index, unsigned Size, unsigned &NumBytesProcessed) { - Value *Cmp = getCompareLoadPairs(Index, Size, NumBytesProcessed); +void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, + unsigned &LoadIndex) { + Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex); - BasicBlock *NextBB = (Index == (LoadCmpBlocks.size() - 1)) + BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1)) ? EndBlock - : LoadCmpBlocks[Index + 1]; + : LoadCmpBlocks[BlockIndex + 1]; // Early exit branch if difference found to ResultBlock. Otherwise, // continue to next LoadCmpBlock or EndBlock. 
BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp); @@ -1979,9 +2033,9 @@ void MemCmpExpansion::emitLoadCompareBlockMultipleLoads( // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0 // since early exit to ResultBlock was not taken (no difference was found in // any of the bytes). - if (Index == LoadCmpBlocks.size() - 1) { + if (BlockIndex == LoadCmpBlocks.size() - 1) { Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0); - PhiRes->addIncoming(Zero, LoadCmpBlocks[Index]); + PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]); } } @@ -1994,33 +2048,39 @@ void MemCmpExpansion::emitLoadCompareBlockMultipleLoads( // the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with // a special case through emitLoadCompareByteBlock. The special handling can // simply subtract the loaded values and add it to the result phi node. -void MemCmpExpansion::emitLoadCompareBlock(unsigned Index, unsigned LoadSize, - unsigned GEPIndex) { - if (LoadSize == 1) { - MemCmpExpansion::emitLoadCompareByteBlock(Index, GEPIndex); +void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { + // There is one load per block in this case, BlockIndex == LoadIndex. + const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex]; + + if (CurLoadEntry.LoadSize == 1) { + MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, + CurLoadEntry.getGEPIndex()); return; } - Type *LoadSizeType = IntegerType::get(CI->getContext(), LoadSize * 8); + Type *LoadSizeType = + IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8); Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); - assert(LoadSize <= MaxLoadSize && "Unexpected load type"); + assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type"); Value *Source1 = CI->getArgOperand(0); Value *Source2 = CI->getArgOperand(1); - Builder.SetInsertPoint(LoadCmpBlocks[Index]); + Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); // Cast source to LoadSizeType*. if (Source1->getType() != LoadSizeType) Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo()); if (Source2->getType() != LoadSizeType) Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); - // Get the base address using the GEPIndex. - if (GEPIndex != 0) { - Source1 = Builder.CreateGEP(LoadSizeType, Source1, - ConstantInt::get(LoadSizeType, GEPIndex)); - Source2 = Builder.CreateGEP(LoadSizeType, Source2, - ConstantInt::get(LoadSizeType, GEPIndex)); + // Get the base address using a GEP. + if (CurLoadEntry.Offset != 0) { + Source1 = Builder.CreateGEP( + LoadSizeType, Source1, + ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); + Source2 = Builder.CreateGEP( + LoadSizeType, Source2, + ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); } // Load LoadSizeType from the base address. @@ -2042,14 +2102,14 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned Index, unsigned LoadSize, // Add the loaded values to the phi nodes for calculating memcmp result only // if result is not used in a zero equality. 
if (!IsUsedForZeroCmp) { - ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[Index]); - ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[Index]); + ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]); + ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]); } Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2); - BasicBlock *NextBB = (Index == (LoadCmpBlocks.size() - 1)) + BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1)) ? EndBlock - : LoadCmpBlocks[Index + 1]; + : LoadCmpBlocks[BlockIndex + 1]; // Early exit branch if difference found to ResultBlock. Otherwise, continue // to next LoadCmpBlock or EndBlock. BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp); @@ -2058,9 +2118,9 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned Index, unsigned LoadSize, // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0 // since early exit to ResultBlock was not taken (no difference was found in // any of the bytes). - if (Index == LoadCmpBlocks.size() - 1) { + if (BlockIndex == LoadCmpBlocks.size() - 1) { Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0); - PhiRes->addIncoming(Zero, LoadCmpBlocks[Index]); + PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]); } } @@ -2094,34 +2154,14 @@ void MemCmpExpansion::emitMemCmpResultBlock() { PhiRes->addIncoming(Res, ResBlock.BB); } -unsigned MemCmpExpansion::calculateNumBlocks(unsigned Size) { - unsigned NumBlocks = 0; - bool HaveOneByteLoad = false; - unsigned RemainingSize = Size; - unsigned LoadSize = MaxLoadSize; - while (RemainingSize) { - if (LoadSize == 1) - HaveOneByteLoad = true; - NumBlocks += RemainingSize / LoadSize; - RemainingSize = RemainingSize % LoadSize; - LoadSize = LoadSize / 2; - } - NumBlocksNonOneByte = HaveOneByteLoad ? (NumBlocks - 1) : NumBlocks; - - if (IsUsedForZeroCmp) - NumBlocks = NumBlocks / NumLoadsPerBlock + - (NumBlocks % NumLoadsPerBlock != 0 ? 1 : 0); - - return NumBlocks; -} - void MemCmpExpansion::setupResultBlockPHINodes() { Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); Builder.SetInsertPoint(ResBlock.BB); + // Note: this assumes one load per block. ResBlock.PhiSrc1 = - Builder.CreatePHI(MaxLoadType, NumBlocksNonOneByte, "phi.src1"); + Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1"); ResBlock.PhiSrc2 = - Builder.CreatePHI(MaxLoadType, NumBlocksNonOneByte, "phi.src2"); + Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2"); } void MemCmpExpansion::setupEndBlockPHINodes() { @@ -2129,12 +2169,13 @@ void MemCmpExpansion::setupEndBlockPHINodes() { PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res"); } -Value *MemCmpExpansion::getMemCmpExpansionZeroCase(unsigned Size) { - unsigned NumBytesProcessed = 0; +Value *MemCmpExpansion::getMemCmpExpansionZeroCase() { + unsigned LoadIndex = 0; // This loop populates each of the LoadCmpBlocks with the IR sequence to // handle multiple loads per block. - for (unsigned i = 0; i < NumBlocks; ++i) - emitLoadCompareBlockMultipleLoads(i, Size, NumBytesProcessed); + for (unsigned I = 0; I < getNumBlocks(); ++I) { + emitLoadCompareBlockMultipleLoads(I, LoadIndex); + } emitMemCmpResultBlock(); return PhiRes; @@ -2143,15 +2184,16 @@ Value *MemCmpExpansion::getMemCmpExpansionZeroCase(unsigned Size) { /// A memcmp expansion that compares equality with 0 and only has one block of /// load and compare can bypass the compare, branch, and phi IR that is required /// in the general case. 
-Value *MemCmpExpansion::getMemCmpEqZeroOneBlock(unsigned Size) { - unsigned NumBytesProcessed = 0; - Value *Cmp = getCompareLoadPairs(0, Size, NumBytesProcessed); +Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() { + unsigned LoadIndex = 0; + Value *Cmp = getCompareLoadPairs(0, LoadIndex); + assert(LoadIndex == getNumLoads() && "some entries were not consumed"); return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext())); } /// A memcmp expansion that only has one block of load and compare can bypass /// the compare, branch, and phi IR that is required in the general case. -Value *MemCmpExpansion::getMemCmpOneBlock(unsigned Size) { +Value *MemCmpExpansion::getMemCmpOneBlock() { assert(NumLoadsPerBlock == 1 && "Only handles one load pair per block"); Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8); @@ -2198,37 +2240,43 @@ Value *MemCmpExpansion::getMemCmpOneBlock(unsigned Size) { // This function expands the memcmp call into an inline expansion and returns // the memcmp result. -Value *MemCmpExpansion::getMemCmpExpansion(uint64_t Size) { +Value *MemCmpExpansion::getMemCmpExpansion() { + computeLoadSequence(); + // A memcmp with zero-comparison with only one block of load and compare does + // not need to set up any extra blocks. This case could be handled in the DAG, + // but since we have all of the machinery to flexibly expand any memcpy here, + // we choose to handle this case too to avoid fragmented lowering. + if ((!IsUsedForZeroCmp && NumLoadsPerBlock != 1) || getNumBlocks() != 1) { + BasicBlock *StartBlock = CI->getParent(); + EndBlock = StartBlock->splitBasicBlock(CI, "endblock"); + setupEndBlockPHINodes(); + createResultBlock(); + + // If return value of memcmp is not used in a zero equality, we need to + // calculate which source was larger. The calculation requires the + // two loaded source values of each load compare block. + // These will be saved in the phi nodes created by setupResultBlockPHINodes. + if (!IsUsedForZeroCmp) setupResultBlockPHINodes(); + + // Create the number of required load compare basic blocks. + createLoadCmpBlocks(); + + // Update the terminator added by splitBasicBlock to branch to the first + // LoadCmpBlock. + StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]); + } + + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + if (IsUsedForZeroCmp) - return NumBlocks == 1 ? getMemCmpEqZeroOneBlock(Size) : - getMemCmpExpansionZeroCase(Size); + return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock() + : getMemCmpExpansionZeroCase(); // TODO: Handle more than one load pair per block in getMemCmpOneBlock(). - if (NumBlocks == 1 && NumLoadsPerBlock == 1) - return getMemCmpOneBlock(Size); - - // This loop calls emitLoadCompareBlock for comparing Size bytes of the two - // memcmp sources. It starts with loading using the maximum load size set by - // the target. It processes any remaining bytes using a load size which is the - // next smallest power of 2. - unsigned LoadSize = MaxLoadSize; - unsigned NumBytesToBeProcessed = Size; - unsigned Index = 0; - while (NumBytesToBeProcessed) { - // Calculate how many blocks we can create with the current load size. - unsigned NumBlocks = NumBytesToBeProcessed / LoadSize; - unsigned GEPIndex = (Size - NumBytesToBeProcessed) / LoadSize; - NumBytesToBeProcessed = NumBytesToBeProcessed % LoadSize; - - // For each NumBlocks, populate the instruction sequence for loading and - // comparing LoadSize bytes. 
- while (NumBlocks--) { - emitLoadCompareBlock(Index, LoadSize, GEPIndex); - Index++; - GEPIndex++; - } - // Get the next LoadSize to use. - LoadSize = LoadSize / 2; + if (getNumBlocks() == 1 && NumLoadsPerBlock == 1) return getMemCmpOneBlock(); + + for (unsigned I = 0; I < getNumBlocks(); ++I) { + emitLoadCompareBlock(I); } emitMemCmpResultBlock(); @@ -2312,12 +2360,6 @@ static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, const TargetLowering *TLI, const DataLayout *DL) { NumMemCmpCalls++; - // TTI call to check if target would like to expand memcmp. Also, get the - // MaxLoadSize. - unsigned MaxLoadSize; - if (!TTI->enableMemCmpExpansion(MaxLoadSize)) - return false; - // Early exit from expansion if -Oz. if (CI->getFunction()->optForMinSize()) return false; @@ -2328,36 +2370,26 @@ static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, NumMemCmpNotConstant++; return false; } + const uint64_t SizeVal = SizeCast->getZExtValue(); - // Scale the max size down if the target can load more bytes than we need. - uint64_t SizeVal = SizeCast->getZExtValue(); - if (MaxLoadSize > SizeVal) - MaxLoadSize = 1 << SizeCast->getValue().logBase2(); + // TTI call to check if target would like to expand memcmp. Also, get the + // max LoadSize. + unsigned MaxLoadSize; + if (!TTI->enableMemCmpExpansion(MaxLoadSize)) return false; - // Calculate how many load pairs are needed for the constant size. - unsigned NumLoads = 0; - unsigned RemainingSize = SizeVal; - unsigned LoadSize = MaxLoadSize; - while (RemainingSize) { - NumLoads += RemainingSize / LoadSize; - RemainingSize = RemainingSize % LoadSize; - LoadSize = LoadSize / 2; - } + MemCmpExpansion Expansion(CI, SizeVal, MaxLoadSize, MemCmpNumLoadsPerBlock, + *DL); // Don't expand if this will require more loads than desired by the target. - if (NumLoads > TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize())) { + if (Expansion.getNumLoads() > + TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize())) { NumMemCmpGreaterThanMax++; return false; } NumMemCmpInlined++; - // MemCmpHelper object creates and sets up basic blocks required for - // expanding memcmp with size SizeVal. - unsigned NumLoadsPerBlock = MemCmpNumLoadsPerBlock; - MemCmpExpansion MemCmpHelper(CI, SizeVal, MaxLoadSize, NumLoadsPerBlock, *DL); - - Value *Res = MemCmpHelper.getMemCmpExpansion(SizeVal); + Value *Res = Expansion.getMemCmpExpansion(); // Replace call with result of expansion and erase call. CI->replaceAllUsesWith(Res); diff --git a/llvm/test/CodeGen/X86/memcmp.ll b/llvm/test/CodeGen/X86/memcmp.ll index b4d5148..04f0856 100644 --- a/llvm/test/CodeGen/X86/memcmp.ll +++ b/llvm/test/CodeGen/X86/memcmp.ll @@ -965,3 +965,10 @@ define i1 @length64_eq_const(i8* %X) nounwind { ret i1 %c } +; This checks that we do not do stupid things with huge sizes. +define i32 @huge_length(i8* %X, i8* %Y) nounwind { + %m = tail call i32 @memcmp(i8* %X, i8* %Y, i64 9223372036854775807) nounwind + ret i32 %m +} + + -- 2.7.4
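
For readers tracing the new decomposition logic, the following is a small self-contained C++ sketch of the two-pass scheme the patch introduces; it is not the LLVM code itself, and the helper name forEachLoadSize, the MaxLoadsAllowed constant and the printf reporting are illustrative stand-ins for getDecomposition(), TLI->getMaxExpandSizeMemcmp() and the pass's bookkeeping. The point being illustrated is that the load count can be computed without ever materializing the load sequence, which is what keeps a size like 0xffffffffffffffff from exhausting memory.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// A LoadEntry mirrors the struct added by the patch: one load of LoadSize
// bytes at Offset bytes from the base pointers.
struct LoadEntry {
  unsigned LoadSize;
  uint64_t Offset;
};

// Walks the greedy power-of-two decomposition of Size bytes, largest loads
// first, and reports each (load size, count) pair to Callback. This mirrors
// MemCmpExpansion::getDecomposition(); forEachLoadSize is an illustrative
// name, not an LLVM API.
template <typename CallbackT>
void forEachLoadSize(uint64_t Size, unsigned MaxLoadSize, CallbackT Callback) {
  assert(Size > 0 && "nothing to compare");
  // Scale the max load size down if the target can load more than we need.
  while (MaxLoadSize > Size)
    MaxLoadSize /= 2;
  unsigned LoadSize = MaxLoadSize;
  uint64_t CurSize = Size;
  while (CurSize) {
    assert(LoadSize > 0 && "ran out of load sizes");
    const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
    if (NumLoadsForThisSize > 0) {
      Callback(LoadSize, NumLoadsForThisSize);
      CurSize %= LoadSize;
    }
    LoadSize /= 2;
  }
}

int main() {
  const uint64_t Size = 33;        // bytes to compare
  const unsigned MaxLoadSize = 16; // e.g. X86 with SSE

  // Pass 1: count the loads only. This runs one iteration per power-of-two
  // load size and allocates nothing, so it is safe even for absurd sizes.
  uint64_t NumLoads = 0;
  forEachLoadSize(Size, MaxLoadSize,
                  [&NumLoads](unsigned, uint64_t N) { NumLoads += N; });
  std::printf("%llu loads for %llu bytes\n", (unsigned long long)NumLoads,
              (unsigned long long)Size);

  // Stand-in for the target's load-count limit; the real threshold comes from
  // TLI->getMaxExpandSizeMemcmp() and depends on the optimization level.
  const uint64_t MaxLoadsAllowed = 8;
  if (NumLoads > MaxLoadsAllowed)
    return 0; // Bail out before building a possibly gigantic sequence.

  // Pass 2: materialize the sequence, as computeLoadSequence() does.
  std::vector<LoadEntry> LoadSequence;
  uint64_t Offset = 0;
  forEachLoadSize(Size, MaxLoadSize, [&](unsigned LoadSize, uint64_t N) {
    for (uint64_t I = 0; I < N; ++I) {
      LoadSequence.push_back({LoadSize, Offset});
      Offset += LoadSize;
    }
  });
  for (const LoadEntry &E : LoadSequence)
    std::printf("{LoadSize=%u, Offset=%llu}\n", E.LoadSize,
                (unsigned long long)E.Offset);
  return 0;
}

For Size = 33 and MaxLoadSize = 16 the second pass prints the [{16, 0}, {16, 16}, {1, 32}] decomposition mentioned in the class comment; for the huge size in the new memcmp.ll test, the first pass alone is enough to reject the expansion.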
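The rewritten getCompareLoadPairs() keeps the existing xor+or combining strategy for calls whose result is only compared against zero, but now walks LoadSequence instead of recomputing load sizes on the fly. The sketch below shows, in plain C++, what that combining computes for one block of two 8-byte load pairs; equal16 and load64 are illustrative helpers, not functions from the patch.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Byte-wise load of a 64-bit value; memcpy keeps it legal for any alignment.
static uint64_t load64(const unsigned char *P) {
  uint64_t V;
  std::memcpy(&V, P, sizeof(V));
  return V;
}

// Equality-only comparison of 16 bytes using two load pairs: each pair is
// XORed, the XOR results are ORed together, and a single test against zero
// decides equality. This is the scalar analogue of the IR emitted by
// getCompareLoadPairs() when NumLoadsPerBlock > 1 and the memcmp result is
// only compared against zero.
static bool equal16(const unsigned char *A, const unsigned char *B) {
  const uint64_t X0 = load64(A) ^ load64(B);
  const uint64_t X1 = load64(A + 8) ^ load64(B + 8);
  return (X0 | X1) == 0; // one branch instead of one per load pair
}

int main() {
  unsigned char A[16] = {0};
  unsigned char B[16] = {0};
  std::printf("equal: %d\n", equal16(A, B)); // 1
  B[13] = 42;
  std::printf("equal: %d\n", equal16(A, B)); // 0
  return 0;
}

Per the comment added in the patch, when the loads in a block have different sizes the IR performs the combination at the widest load type, which is why MaxLoadType is computed up front in getCompareLoadPairs().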