From 65a3197a8fa2e5d1deb8707bda13ebd21e1dedb3 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 23 Feb 2021 14:22:23 -0800 Subject: [PATCH] [mlir] Refactor InterfaceMap to use a sorted vector of interfaces, as opposed to a DenseMap A majority of operations have a very small number of interfaces, which means that the cost of using a hash map is generally larger for interface lookups than just a binary search. In the future when there are a number of operations with large amounts of interfaces, we can switch to a hybrid approach that optimizes lookups based on the number of interfaces. For now, however, a binary search is the best approach. This dropped compile time on a largish TF MLIR module by 20%(half a second). Differential Revision: https://reviews.llvm.org/D96085 --- mlir/include/mlir/IR/OperationSupport.h | 4 +++- mlir/include/mlir/Support/InterfaceSupport.h | 35 +++++++++++++++++++--------- mlir/lib/IR/Operation.cpp | 4 ---- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index b6283fc..6dcd716 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -327,7 +327,9 @@ public: /// If this operation has a registered operation description, return it. /// Otherwise return null. - const AbstractOperation *getAbstractOperation() const; + const AbstractOperation *getAbstractOperation() const { + return representation.dyn_cast(); + } void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h index 44b0f67..b618e8e 100644 --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -152,10 +152,8 @@ class InterfaceMap { public: InterfaceMap(InterfaceMap &&) = default; ~InterfaceMap() { - if (interfaces) { - for (auto &it : *interfaces) - free(it.second); - } + for (auto &it : interfaces) + free(it.second); } /// Construct an InterfaceMap with the given set of template types. For @@ -182,15 +180,22 @@ public: /// Returns an instance of the concept object for the given interface if it /// was registered to this map, null otherwise. template typename T::Concept *lookup() const { - void *inst = interfaces ? interfaces->lookup(T::getInterfaceID()) : nullptr; - return reinterpret_cast(inst); + return reinterpret_cast(lookup(T::getInterfaceID())); } private: + /// Compare two TypeID instances by comparing the underlying pointer. + static bool compare(TypeID lhs, TypeID rhs) { + return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer(); + } + InterfaceMap() = default; InterfaceMap(MutableArrayRef> elements) - : interfaces(std::make_unique>( - elements.begin(), elements.end())) {} + : interfaces(elements.begin(), elements.end()) { + llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) { + return compare(lhs.first, rhs.first); + }); + } template static InterfaceMap getImpl(std::tuple *) { @@ -200,9 +205,17 @@ private: return InterfaceMap(elements); } - /// The internal map of interfaces. This is constructed statically for each - /// set of interfaces. - std::unique_ptr> interfaces; + /// Returns an instance of the concept object for the given interface id if it + /// was registered to this map, null otherwise. + void *lookup(TypeID id) const { + auto it = llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) { + return compare(it.first, id); + }); + return (it != interfaces.end() && it->first == id) ? it->second : nullptr; + } + + /// A list of interface instances, sorted by TypeID. + SmallVector> interfaces; }; } // end namespace detail diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index f1c40b0..9349e4c 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -57,10 +57,6 @@ Identifier OperationName::getIdentifier() const { return representation.get(); } -const AbstractOperation *OperationName::getAbstractOperation() const { - return representation.dyn_cast(); -} - OperationName OperationName::getFromOpaquePointer(const void *pointer) { return OperationName( RepresentationUnion::getFromOpaqueValue(const_cast(pointer))); -- 2.7.4