[mlir] Refactor InterfaceMap to use a sorted vector of interfaces, as opposed to...
authorRiver Riddle <riddleriver@gmail.com>
Tue, 23 Feb 2021 22:22:23 +0000 (14:22 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 23 Feb 2021 22:36:45 +0000 (14:36 -0800)
A majority of operations have a very small number of interfaces, which means that the cost of using a hash map is generally larger for interface lookups than just a binary search. In the future when there are a number of operations with large amounts of interfaces, we can switch to a hybrid approach that optimizes lookups based on the number of interfaces. For now, however, a binary search is the best approach.

This dropped compile time on a largish TF MLIR module by 20%(half a second).

Differential Revision: https://reviews.llvm.org/D96085

mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/Support/InterfaceSupport.h
mlir/lib/IR/Operation.cpp

index b6283fc..6dcd716 100644 (file)
@@ -327,7 +327,9 @@ public:
 
   /// If this operation has a registered operation description, return it.
   /// Otherwise return null.
-  const AbstractOperation *getAbstractOperation() const;
+  const AbstractOperation *getAbstractOperation() const {
+    return representation.dyn_cast<const AbstractOperation *>();
+  }
 
   void print(raw_ostream &os) const;
   void dump() const;
index 44b0f67..b618e8e 100644 (file)
@@ -152,10 +152,8 @@ class InterfaceMap {
 public:
   InterfaceMap(InterfaceMap &&) = default;
   ~InterfaceMap() {
-    if (interfaces) {
-      for (auto &it : *interfaces)
-        free(it.second);
-    }
+    for (auto &it : interfaces)
+      free(it.second);
   }
 
   /// Construct an InterfaceMap with the given set of template types. For
@@ -182,15 +180,22 @@ public:
   /// Returns an instance of the concept object for the given interface if it
   /// was registered to this map, null otherwise.
   template <typename T> typename T::Concept *lookup() const {
-    void *inst = interfaces ? interfaces->lookup(T::getInterfaceID()) : nullptr;
-    return reinterpret_cast<typename T::Concept *>(inst);
+    return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID()));
   }
 
 private:
+  /// Compare two TypeID instances by comparing the underlying pointer.
+  static bool compare(TypeID lhs, TypeID rhs) {
+    return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
+  }
+
   InterfaceMap() = default;
   InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements)
-      : interfaces(std::make_unique<llvm::SmallDenseMap<TypeID, void *>>(
-            elements.begin(), elements.end())) {}
+      : interfaces(elements.begin(), elements.end()) {
+    llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) {
+      return compare(lhs.first, rhs.first);
+    });
+  }
 
   template <typename... Ts>
   static InterfaceMap getImpl(std::tuple<Ts...> *) {
@@ -200,9 +205,17 @@ private:
     return InterfaceMap(elements);
   }
 
-  /// The internal map of interfaces. This is constructed statically for each
-  /// set of interfaces.
-  std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> interfaces;
+  /// Returns an instance of the concept object for the given interface id if it
+  /// was registered to this map, null otherwise.
+  void *lookup(TypeID id) const {
+    auto it = llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
+      return compare(it.first, id);
+    });
+    return (it != interfaces.end() && it->first == id) ? it->second : nullptr;
+  }
+
+  /// A list of interface instances, sorted by TypeID.
+  SmallVector<std::pair<TypeID, void *>> interfaces;
 };
 
 } // end namespace detail
index f1c40b0..9349e4c 100644 (file)
@@ -57,10 +57,6 @@ Identifier OperationName::getIdentifier() const {
   return representation.get<Identifier>();
 }
 
-const AbstractOperation *OperationName::getAbstractOperation() const {
-  return representation.dyn_cast<const AbstractOperation *>();
-}
-
 OperationName OperationName::getFromOpaquePointer(const void *pointer) {
   return OperationName(
       RepresentationUnion::getFromOpaqueValue(const_cast<void *>(pointer)));