From ae3f54c1e909743a89d48a8a05e18d2c8fd652ba Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek
Date: Tue, 1 Sep 2020 17:31:12 -0500
Subject: [PATCH] [EarlyCSE] Handle masked loads and stores

Extend the handling of memory intrinsics to also include
non-target-specific intrinsics, in particular masked loads and stores.

Invent "isHandledNonTargetIntrinsic" to distinguish intrinsics that
should be handled natively from intrinsics that can be passed to TTI.

Add code that handles masked loads and stores and update the testcase
to reflect the results.

Differential Revision: https://reviews.llvm.org/D87340
---
 llvm/lib/Transforms/Scalar/EarlyCSE.cpp            | 216 +++++++++++++++++++--
 .../EarlyCSE/masked-intrinsics-unequal-masks.ll    |  10 +-
 llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll |   4 +-
 3 files changed, 202 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index acdddce..5eb2e12 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -689,8 +689,33 @@ private:
     ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI)
         : Inst(Inst) {
       if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+        IntrID = II->getIntrinsicID();
         if (TTI.getTgtMemIntrinsic(II, Info))
-          IntrID = II->getIntrinsicID();
+          return;
+        if (isHandledNonTargetIntrinsic(IntrID)) {
+          switch (IntrID) {
+          case Intrinsic::masked_load:
+            Info.PtrVal = Inst->getOperand(0);
+            Info.MatchingId = Intrinsic::masked_load;
+            Info.ReadMem = true;
+            Info.WriteMem = false;
+            Info.IsVolatile = false;
+            break;
+          case Intrinsic::masked_store:
+            Info.PtrVal = Inst->getOperand(1);
+            // Use the ID of masked load as the "matching id". This will
+            // prevent matching non-masked loads/stores with masked ones
+            // (which could be done), but at the moment, the code here
+            // does not support matching intrinsics with non-intrinsics,
+            // so keep the MatchingIds specific to masked instructions
+            // for now (TODO).
+            Info.MatchingId = Intrinsic::masked_load;
+            Info.ReadMem = false;
+            Info.WriteMem = true;
+            Info.IsVolatile = false;
+            break;
+          }
+        }
       }
     }
 
@@ -747,11 +772,6 @@ private:
       return false;
     }
 
-    bool isMatchingMemLoc(const ParseMemoryInst &Inst) const {
-      return (getPointerOperand() == Inst.getPointerOperand() &&
-              getMatchingId() == Inst.getMatchingId());
-    }
-
    bool isValid() const { return getPointerOperand() != nullptr; }
 
    // For regular (non-intrinsic) loads/stores, this is set to -1. For
@@ -788,6 +808,22 @@ private:
     Instruction *Inst;
   };
 
+  // This function is to prevent accidentally passing a non-target
+  // intrinsic ID to TargetTransformInfo.
+  static bool isHandledNonTargetIntrinsic(Intrinsic::ID ID) {
+    switch (ID) {
+    case Intrinsic::masked_load:
+    case Intrinsic::masked_store:
+      return true;
+    }
+    return false;
+  }
+  static bool isHandledNonTargetIntrinsic(const Value *V) {
+    if (auto *II = dyn_cast<IntrinsicInst>(V))
+      return isHandledNonTargetIntrinsic(II->getIntrinsicID());
+    return false;
+  }
+
   bool processNode(DomTreeNode *Node);
 
   bool handleBranchCondition(Instruction *CondInst, const BranchInst *BI,
@@ -796,14 +832,30 @@ private:
   Value *getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
                           unsigned CurrentGeneration);
 
+  bool overridingStores(const ParseMemoryInst &Earlier,
+                        const ParseMemoryInst &Later);
+
   Value *getOrCreateResult(Value *Inst, Type *ExpectedType) const {
     if (auto *LI = dyn_cast<LoadInst>(Inst))
       return LI;
     if (auto *SI = dyn_cast<StoreInst>(Inst))
       return SI->getValueOperand();
     assert(isa<IntrinsicInst>(Inst) && "Instruction not supported");
-    return TTI.getOrCreateResultFromMemIntrinsic(cast<IntrinsicInst>(Inst),
-                                                 ExpectedType);
+    auto *II = cast<IntrinsicInst>(Inst);
+    if (isHandledNonTargetIntrinsic(II->getIntrinsicID()))
+      return getOrCreateResultNonTargetMemIntrinsic(II, ExpectedType);
+    return TTI.getOrCreateResultFromMemIntrinsic(II, ExpectedType);
+  }
+
+  Value *getOrCreateResultNonTargetMemIntrinsic(IntrinsicInst *II,
+                                                Type *ExpectedType) const {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::masked_load:
+      return II;
+    case Intrinsic::masked_store:
+      return II->getOperand(0);
+    }
+    return nullptr;
   }
 
   /// Return true if the instruction is known to only operate on memory
@@ -813,6 +865,101 @@ private:
   bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration,
                            Instruction *EarlierInst, Instruction *LaterInst);
 
+  bool isNonTargetIntrinsicMatch(const IntrinsicInst *Earlier,
+                                 const IntrinsicInst *Later) {
+    auto IsSubmask = [](const Value *Mask0, const Value *Mask1) {
+      // Is Mask0 a submask of Mask1?
+      if (Mask0 == Mask1)
+        return true;
+      if (isa<UndefValue>(Mask0) || isa<UndefValue>(Mask1))
+        return false;
+      auto *Vec0 = dyn_cast<ConstantVector>(Mask0);
+      auto *Vec1 = dyn_cast<ConstantVector>(Mask1);
+      if (!Vec0 || !Vec1)
+        return false;
+      assert(Vec0->getType() == Vec1->getType() &&
+             "Masks should have the same type");
+      for (int i = 0, e = Vec0->getNumOperands(); i != e; ++i) {
+        Constant *Elem0 = Vec0->getOperand(i);
+        Constant *Elem1 = Vec1->getOperand(i);
+        auto *Int0 = dyn_cast<ConstantInt>(Elem0);
+        if (Int0 && Int0->isZero())
+          continue;
+        auto *Int1 = dyn_cast<ConstantInt>(Elem1);
+        if (Int1 && !Int1->isZero())
+          continue;
+        if (isa<UndefValue>(Elem0) || isa<UndefValue>(Elem1))
+          return false;
+        if (Elem0 == Elem1)
+          continue;
+        return false;
+      }
+      return true;
+    };
+    auto PtrOp = [](const IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(0);
+      if (II->getIntrinsicID() == Intrinsic::masked_store)
+        return II->getOperand(1);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+    auto MaskOp = [](const IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(2);
+      if (II->getIntrinsicID() == Intrinsic::masked_store)
+        return II->getOperand(3);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+    auto ThruOp = [](const IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(3);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+
+    if (PtrOp(Earlier) != PtrOp(Later))
+      return false;
+
+    Intrinsic::ID IDE = Earlier->getIntrinsicID();
+    Intrinsic::ID IDL = Later->getIntrinsicID();
+    // We could really use specific intrinsic classes for masked loads
+    // and stores in IntrinsicInst.h.
+    if (IDE == Intrinsic::masked_load && IDL == Intrinsic::masked_load) {
+      // Trying to replace later masked load with the earlier one.
+      // Check that the pointers are the same, and
+      // - masks and pass-throughs are the same, or
+      // - replacee's pass-through is "undef" and replacer's mask is a
+      //   super-set of the replacee's mask.
+      if (MaskOp(Earlier) == MaskOp(Later) && ThruOp(Earlier) == ThruOp(Later))
+        return true;
+      if (!isa<UndefValue>(ThruOp(Later)))
+        return false;
+      return IsSubmask(MaskOp(Later), MaskOp(Earlier));
+    }
+    if (IDE == Intrinsic::masked_store && IDL == Intrinsic::masked_load) {
+      // Trying to replace a load of a stored value with the store's value.
+      // Check that the pointers are the same, and
+      // - load's mask is a subset of store's mask, and
+      // - load's pass-through is "undef".
+      if (!IsSubmask(MaskOp(Later), MaskOp(Earlier)))
+        return false;
+      return isa<UndefValue>(ThruOp(Later));
+    }
+    if (IDE == Intrinsic::masked_load && IDL == Intrinsic::masked_store) {
+      // Trying to remove a store of the loaded value.
+      // Check that the pointers are the same, and
+      // - store's mask is a subset of the load's mask.
+      return IsSubmask(MaskOp(Later), MaskOp(Earlier));
+    }
+    if (IDE == Intrinsic::masked_store && IDL == Intrinsic::masked_store) {
+      // Trying to remove a dead store (earlier).
+      // Check that the pointers are the same,
+      // - the to-be-removed store's mask is a subset of the other store's
+      //   mask.
+      return IsSubmask(MaskOp(Earlier), MaskOp(Later));
+    }
+    return false;
+  }
+
   void removeMSSA(Instruction &Inst) {
     if (!MSSA)
       return;
@@ -978,6 +1125,17 @@ Value *EarlyCSE::getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
   Instruction *Matching = MemInstMatching ? MemInst.get() : InVal.DefInst;
   Instruction *Other = MemInstMatching ? InVal.DefInst : MemInst.get();
 
+  // Deal with non-target memory intrinsics.
+  bool MatchingNTI = isHandledNonTargetIntrinsic(Matching);
+  bool OtherNTI = isHandledNonTargetIntrinsic(Other);
+  if (OtherNTI != MatchingNTI)
+    return nullptr;
+  if (OtherNTI && MatchingNTI) {
+    if (!isNonTargetIntrinsicMatch(cast<IntrinsicInst>(InVal.DefInst),
+                                   cast<IntrinsicInst>(MemInst.get())))
+      return nullptr;
+  }
+
   if (!isOperatingOnInvariantMemAt(MemInst.get(), InVal.Generation) &&
       !isSameMemGeneration(InVal.Generation, CurrentGeneration, InVal.DefInst,
                            MemInst.get()))
@@ -985,6 +1143,37 @@ Value *EarlyCSE::getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
   return getOrCreateResult(Matching, Other->getType());
 }
 
+bool EarlyCSE::overridingStores(const ParseMemoryInst &Earlier,
+                                const ParseMemoryInst &Later) {
+  // Can we remove Earlier store because of Later store?
+
+  assert(Earlier.isUnordered() && !Earlier.isVolatile() &&
+         "Violated invariant");
+  if (Earlier.getPointerOperand() != Later.getPointerOperand())
+    return false;
+  if (Earlier.getMatchingId() != Later.getMatchingId())
+    return false;
+  // At the moment, we don't remove ordered stores, but do remove
+  // unordered atomic stores. There's no special requirement (for
+  // unordered atomics) about removing atomic stores only in favor of
+  // other atomic stores since we were going to execute the non-atomic
+  // one anyway and the atomic one might never have become visible.
+  if (!Earlier.isUnordered() || !Later.isUnordered())
+    return false;
+
+  // Deal with non-target memory intrinsics.
+  bool ENTI = isHandledNonTargetIntrinsic(Earlier.get());
+  bool LNTI = isHandledNonTargetIntrinsic(Later.get());
+  if (ENTI && LNTI)
+    return isNonTargetIntrinsicMatch(cast<IntrinsicInst>(Earlier.get()),
+                                     cast<IntrinsicInst>(Later.get()));
+
+  // Because of the check above, at least one of them is false.
+  // For now disallow matching intrinsics with non-intrinsics,
+  // so assume that the stores match if neither is an intrinsic.
+  return ENTI == LNTI;
+}
+
 bool EarlyCSE::processNode(DomTreeNode *Node) {
   bool Changed = false;
   BasicBlock *BB = Node->getBlock();
@@ -1320,17 +1509,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
     if (MemInst.isValid() && MemInst.isStore()) {
       // We do a trivial form of DSE if there are two stores to the same
       // location with no intervening loads. Delete the earlier store.
-      // At the moment, we don't remove ordered stores, but do remove
-      // unordered atomic stores. There's no special requirement (for
-      // unordered atomics) about removing atomic stores only in favor of
-      // other atomic stores since we were going to execute the non-atomic
-      // one anyway and the atomic one might never have become visible.
      if (LastStore) {
-        ParseMemoryInst LastStoreMemInst(LastStore, TTI);
-        assert(LastStoreMemInst.isUnordered() &&
-               !LastStoreMemInst.isVolatile() &&
-               "Violated invariant");
-        if (LastStoreMemInst.isMatchingMemLoc(MemInst)) {
+        if (overridingStores(ParseMemoryInst(LastStore, TTI), MemInst)) {
           LLVM_DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore
                             << " due to: " << Inst << '\n');
           if (!DebugCounter::shouldExecute(CSECounter)) {
diff --git a/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll b/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll
index 31c250a..cf5641d 100644
--- a/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll
+++ b/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll
@@ -13,8 +13,7 @@
 define <4 x i32> @f3(<4 x i32>* %a0, <4 x i32> %a1) {
 ; CHECK-LABEL: @f3(
 ; CHECK-NEXT: [[V0:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A0:%.*]], i32 4, <4 x i1> , <4 x i32> [[A1:%.*]])
-; CHECK-NEXT: [[V1:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A0]], i32 4, <4 x i1> , <4 x i32> undef)
-; CHECK-NEXT: [[V2:%.*]] = add <4 x i32> [[V0]], [[V1]]
+; CHECK-NEXT: [[V2:%.*]] = add <4 x i32> [[V0]], [[V0]]
 ; CHECK-NEXT: ret <4 x i32> [[V2]]
 ;
   %v0 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %a0, i32 4, <4 x i1> , <4 x i32> %a1)
@@ -60,8 +59,7 @@ define <4 x i32> @f5(<4 x i32>* %a0, <4 x i32> %a1) {
 ; Expect the first store to be removed.
 define void @f6(<4 x i32> %a0, <4 x i32>* %a1) {
 ; CHECK-LABEL: @f6(
-; CHECK-NEXT: call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0:%.*]], <4 x i32>* [[A1:%.*]], i32 4, <4 x i1> )
-; CHECK-NEXT: call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0]], <4 x i32>* [[A1]], i32 4, <4 x i1> )
+; CHECK-NEXT: call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0:%.*]], <4 x i32>* [[A1:%.*]], i32 4, <4 x i1> )
 ; CHECK-NEXT: ret void
 ;
   call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %a0, <4 x i32>* %a1, i32 4, <4 x i1> )
@@ -90,7 +88,6 @@ define void @f7(<4 x i32> %a0, <4 x i32>* %a1) {
 define <4 x i32> @f8(<4 x i32>* %a0, <4 x i32> %a1) {
 ; CHECK-LABEL: @f8(
 ; CHECK-NEXT: [[V0:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A0:%.*]], i32 4, <4 x i1> , <4 x i32> [[A1:%.*]])
-; CHECK-NEXT: call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[V0]], <4 x i32>* [[A0]], i32 4, <4 x i1> )
 ; CHECK-NEXT: ret <4 x i32> [[V0]]
 ;
   %v0 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %a0, i32 4, <4 x i1> , <4 x i32> %a1)
@@ -119,8 +116,7 @@ define <4 x i32> @f9(<4 x i32>* %a0, <4 x i32> %a1) {
 define <4 x i32> @fa(<4 x i32> %a0, <4 x i32>* %a1) {
 ; CHECK-LABEL: @fa(
 ; CHECK-NEXT: call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0:%.*]], <4 x i32>* [[A1:%.*]], i32 4, <4 x i1> )
-; CHECK-NEXT: [[V0:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A1]], i32 4, <4 x i1> , <4 x i32> undef)
-; CHECK-NEXT: ret <4 x i32> [[V0]]
+; CHECK-NEXT: ret <4 x i32> [[A0]]
 ;
   call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %a0, <4 x i32>* %a1, i32 4, <4 x i1> )
   %v0 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %a1, i32 4, <4 x i1> , <4 x i32> undef)
diff --git a/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll b/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
index 77183ab..392a487 100644
--- a/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
+++ b/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
@@ -5,8 +5,7 @@ define <128 x i8> @f0(<128 x i8>* %a0, <128 x i8> %a1, <128 x i8> %a2) {
 ; CHECK-LABEL: @f0(
 ; CHECK-NEXT: [[V0:%.*]] = icmp eq <128 x i8> [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT: call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[A1]], <128 x i8>* [[A0:%.*]], i32 4, <128 x i1> [[V0]])
-; CHECK-NEXT: [[V1:%.*]] = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[A0]], i32 4, <128 x i1> [[V0]], <128 x i8> undef)
-; CHECK-NEXT: ret <128 x i8> [[V1]]
+; CHECK-NEXT: ret <128 x i8> [[A1]]
 ;
   %v0 = icmp eq <128 x i8> %a1, %a2
   call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> %a1, <128 x i8>* %a0, i32 4, <128 x i1> %v0)
@@ -18,7 +17,6 @@ define <128 x i8> @f1(<128 x i8>* %a0, <128 x i8> %a1, <128 x i8> %a2) {
 ; CHECK-LABEL: @f1(
 ; CHECK-NEXT: [[V0:%.*]] = icmp eq <128 x i8> [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT: [[V1:%.*]] = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[A0:%.*]], i32 4, <128 x i1> [[V0]], <128 x i8> undef)
-; CHECK-NEXT: call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[V1]], <128 x i8>* [[A0]], i32 4, <128 x i1> [[V0]])
 ; CHECK-NEXT: ret <128 x i8> [[V1]]
 ;
   %v0 = icmp eq <128 x i8> %a1, %a2
-- 
2.7.4
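
Note (not part of the patch): below is a minimal IR sketch of the store-to-load forwarding this change enables, adapted from the @f0 test above; the function and value names are illustrative only. Running EarlyCSE over it (for example with "opt -passes=early-cse -S") is expected to replace the masked load with the stored value %a1, since the load reads the same pointer under the same mask and has an undef pass-through.

; Illustrative example, assuming the patch above is applied.
define <128 x i8> @example(<128 x i8>* %p, <128 x i8> %a1, <128 x i1> %mask) {
  call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> %a1, <128 x i8>* %p, i32 4, <128 x i1> %mask)
  ; EarlyCSE is expected to forward %a1 here and drop the now-redundant load.
  %v1 = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* %p, i32 4, <128 x i1> %mask, <128 x i8> undef)
  ret <128 x i8> %v1
}

declare void @llvm.masked.store.v128i8.p0v128i8(<128 x i8>, <128 x i8>*, i32, <128 x i1>)
declare <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>*, i32, <128 x i1>, <128 x i8>)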