From 9b4984926c74e79a9152f2a97e0b733ac797d282 Mon Sep 17 00:00:00 2001 From: Daniel Berlin Date: Sat, 1 Apr 2017 09:44:29 +0000 Subject: [PATCH] NewGVN: Clean up GVNExpression memory hierarchy, restructure hash computation a bit so we don't have to redefine it for loads, stores, and calls llvm-svn: 299299 --- .../include/llvm/Transforms/Scalar/GVNExpression.h | 126 +++++++++++---------- llvm/lib/Transforms/Scalar/NewGVN.cpp | 24 ++-- 2 files changed, 75 insertions(+), 75 deletions(-) diff --git a/llvm/include/llvm/Transforms/Scalar/GVNExpression.h b/llvm/include/llvm/Transforms/Scalar/GVNExpression.h index ad5bb40..142f231 100644 --- a/llvm/include/llvm/Transforms/Scalar/GVNExpression.h +++ b/llvm/include/llvm/Transforms/Scalar/GVNExpression.h @@ -43,11 +43,13 @@ enum ExpressionType { ET_Unknown, ET_BasicStart, ET_Basic, - ET_Call, ET_AggregateValue, ET_Phi, + ET_MemoryStart, + ET_Call, ET_Load, ET_Store, + ET_MemoryEnd, ET_BasicEnd }; @@ -72,8 +74,6 @@ public: if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey()) return true; // Compare the expression type for anything but load and store. - // For load and store we set the opcode to zero. - // This is needed for load coercion. if (getExpressionType() != ET_Load && getExpressionType() != ET_Store && getExpressionType() != Other.getExpressionType()) return false; @@ -87,9 +87,8 @@ public: void setOpcode(unsigned opcode) { Opcode = opcode; } ExpressionType getExpressionType() const { return EType; } - virtual hash_code getHashValue() const { - return hash_combine(getExpressionType(), getOpcode()); - } + // We deliberately leave the expression type out of the hash value. + virtual hash_code getHashValue() const { return getOpcode(); } // // Debugging support @@ -106,7 +105,10 @@ public: OS << "}"; } - void dump() const { print(dbgs()); } + LLVM_DUMP_METHOD void dump() const { + print(dbgs()); + dbgs() << "\n"; + } }; inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) { @@ -200,7 +202,7 @@ public: } hash_code getHashValue() const override { - return hash_combine(getExpressionType(), getOpcode(), ValueType, + return hash_combine(this->Expression::getHashValue(), ValueType, hash_combine_range(op_begin(), op_end())); } @@ -241,32 +243,53 @@ public: op_inserter &operator++(int) { return *this; } }; -class CallExpression final : public BasicExpression { +class MemoryExpression : public BasicExpression { private: - CallInst *Call; - MemoryAccess *DefiningAccess; + const MemoryAccess *MemoryLeader; public: - CallExpression(unsigned NumOperands, CallInst *C, MemoryAccess *DA) - : BasicExpression(NumOperands, ET_Call), Call(C), DefiningAccess(DA) {} - CallExpression() = delete; - CallExpression(const CallExpression &) = delete; - CallExpression &operator=(const CallExpression &) = delete; - ~CallExpression() override; + MemoryExpression(unsigned NumOperands, enum ExpressionType EType, + const MemoryAccess *MemoryLeader) + : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader){}; + MemoryExpression() = delete; + MemoryExpression(const MemoryExpression &) = delete; + MemoryExpression &operator=(const MemoryExpression &) = delete; static bool classof(const Expression *EB) { - return EB->getExpressionType() == ET_Call; + return EB->getExpressionType() > ET_MemoryStart && + EB->getExpressionType() < ET_MemoryEnd; + } + hash_code getHashValue() const override { + return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader); } bool equals(const Expression &Other) const override { if (!this->BasicExpression::equals(Other)) return false; - const auto &OE = cast(Other); - return DefiningAccess == OE.DefiningAccess; + const MemoryExpression &OtherMCE = cast(Other); + + return MemoryLeader == OtherMCE.MemoryLeader; } - hash_code getHashValue() const override { - return hash_combine(this->BasicExpression::getHashValue(), DefiningAccess); + const MemoryAccess *getMemoryLeader() const { return MemoryLeader; } + void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; } +}; + +class CallExpression final : public MemoryExpression { +private: + CallInst *Call; + +public: + CallExpression(unsigned NumOperands, CallInst *C, + const MemoryAccess *MemoryLeader) + : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {} + CallExpression() = delete; + CallExpression(const CallExpression &) = delete; + CallExpression &operator=(const CallExpression &) = delete; + ~CallExpression() override; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Call; } // @@ -276,22 +299,23 @@ public: if (PrintEType) OS << "ExpressionTypeCall, "; this->BasicExpression::printInternal(OS, false); - OS << " represents call at " << Call; + OS << " represents call at "; + Call->printAsOperand(OS); } }; -class LoadExpression final : public BasicExpression { +class LoadExpression final : public MemoryExpression { private: LoadInst *Load; - MemoryAccess *DefiningAccess; unsigned Alignment; public: - LoadExpression(unsigned NumOperands, LoadInst *L, MemoryAccess *DA) - : LoadExpression(ET_Load, NumOperands, L, DA) {} + LoadExpression(unsigned NumOperands, LoadInst *L, + const MemoryAccess *MemoryLeader) + : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {} LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L, - MemoryAccess *DA) - : BasicExpression(NumOperands, EType), Load(L), DefiningAccess(DA) { + const MemoryAccess *MemoryLeader) + : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) { Alignment = L ? L->getAlignment() : 0; } LoadExpression() = delete; @@ -306,18 +330,11 @@ public: LoadInst *getLoadInst() const { return Load; } void setLoadInst(LoadInst *L) { Load = L; } - MemoryAccess *getDefiningAccess() const { return DefiningAccess; } - void setDefiningAccess(MemoryAccess *MA) { DefiningAccess = MA; } unsigned getAlignment() const { return Alignment; } void setAlignment(unsigned Align) { Alignment = Align; } bool equals(const Expression &Other) const override; - hash_code getHashValue() const override { - return hash_combine(getOpcode(), getType(), DefiningAccess, - hash_combine_range(op_begin(), op_end())); - } - // // Debugging support // @@ -325,22 +342,22 @@ public: if (PrintEType) OS << "ExpressionTypeLoad, "; this->BasicExpression::printInternal(OS, false); - OS << " represents Load at " << Load; - OS << " with DefiningAccess " << *DefiningAccess; + OS << " represents Load at "; + Load->printAsOperand(OS); + OS << " with MemoryLeader " << *getMemoryLeader(); } }; -class StoreExpression final : public BasicExpression { +class StoreExpression final : public MemoryExpression { private: StoreInst *Store; Value *StoredValue; - MemoryAccess *DefiningAccess; public: StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue, - MemoryAccess *DA) - : BasicExpression(NumOperands, ET_Store), Store(S), - StoredValue(StoredValue), DefiningAccess(DA) {} + const MemoryAccess *MemoryLeader) + : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S), + StoredValue(StoredValue) {} StoreExpression() = delete; StoreExpression(const StoreExpression &) = delete; StoreExpression &operator=(const StoreExpression &) = delete; @@ -351,27 +368,18 @@ public: } StoreInst *getStoreInst() const { return Store; } - MemoryAccess *getDefiningAccess() const { return DefiningAccess; } Value *getStoredValue() const { return StoredValue; } bool equals(const Expression &Other) const override; - hash_code getHashValue() const override { - // This deliberately does not include the stored value we compare it as part - // of equals, and only against other stores. - return hash_combine(getOpcode(), getType(), DefiningAccess, - hash_combine_range(op_begin(), op_end())); - } - - // // Debugging support // void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) OS << "ExpressionTypeStore, "; this->BasicExpression::printInternal(OS, false); - OS << " represents Store at " << Store; - OS << " with DefiningAccess " << *DefiningAccess; + OS << " represents Store " << *Store; + OS << " with MemoryLeader " << *getMemoryLeader(); } }; @@ -527,8 +535,8 @@ public: } hash_code getHashValue() const override { - return hash_combine(getExpressionType(), VariableValue->getType(), - VariableValue); + return hash_combine(this->Expression::getHashValue(), + VariableValue->getType(), VariableValue); } // @@ -566,8 +574,8 @@ public: } hash_code getHashValue() const override { - return hash_combine(getExpressionType(), ConstantValue->getType(), - ConstantValue); + return hash_combine(this->Expression::getHashValue(), + ConstantValue->getType(), ConstantValue); } // @@ -604,7 +612,7 @@ public: } hash_code getHashValue() const override { - return hash_combine(getExpressionType(), Inst); + return hash_combine(this->Expression::getHashValue(), Inst); } // diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp index 34efa78..a5cc160 100644 --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -429,17 +429,9 @@ private: template static bool equalsLoadStoreHelper(const T &LHS, const Expression &RHS) { - if ((!isa(RHS) && !isa(RHS)) || - !LHS.BasicExpression::equals(RHS)) { + if (!isa(RHS) && !isa(RHS)) return false; - } else if (const auto *L = dyn_cast(&RHS)) { - if (LHS.getDefiningAccess() != L->getDefiningAccess()) - return false; - } else if (const auto *S = dyn_cast(&RHS)) { - if (LHS.getDefiningAccess() != S->getDefiningAccess()) - return false; - } - return true; + return LHS.MemoryExpression::equals(RHS); } bool LoadExpression::equals(const Expression &Other) const { @@ -447,13 +439,13 @@ bool LoadExpression::equals(const Expression &Other) const { } bool StoreExpression::equals(const Expression &Other) const { - bool Result = equalsLoadStoreHelper(*this, Other); + if (!equalsLoadStoreHelper(*this, Other)) + return false; // Make sure that store vs store includes the value operand. - if (Result) - if (const auto *S = dyn_cast(&Other)) - if (getStoredValue() != S->getStoredValue()) - return false; - return Result; + if (const auto *S = dyn_cast(&Other)) + if (getStoredValue() != S->getStoredValue()) + return false; + return true; } #ifndef NDEBUG -- 2.7.4