From 55e02afc5a349139654c883842fd7b2d46f7d0a6 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 19 Jun 2020 07:35:44 -0700 Subject: [PATCH] [Object] Introduce POD-C Compliant tvm::Map (#5740) --- 3rdparty/compiler-rt/builtin_fp16.h | 6 +- CMakeLists.txt | 11 + apps/android_deploy/app/src/main/jni/Android.mk | 1 - apps/android_rpc/app/src/main/jni/Android.mk | 1 - apps/howto_deploy/Makefile | 2 +- cmake/config.cmake | 3 + include/tvm/node/container.h | 1406 +++++++++++++++++++++-- include/tvm/runtime/container.h | 51 +- include/tvm/runtime/object.h | 3 + src/ir/expr.cc | 4 +- src/ir/module.cc | 4 +- src/node/container.cc | 58 +- src/node/reflection.cc | 2 +- src/node/serialization.cc | 224 +++- src/runtime/object_internal.h | 10 + src/tir/ir/transform.cc | 4 +- src/tir/transforms/split_host_device.cc | 2 +- tests/cpp/container_test.cc | 49 +- 18 files changed, 1611 insertions(+), 230 deletions(-) diff --git a/3rdparty/compiler-rt/builtin_fp16.h b/3rdparty/compiler-rt/builtin_fp16.h index 8048980..bd2e677 100644 --- a/3rdparty/compiler-rt/builtin_fp16.h +++ b/3rdparty/compiler-rt/builtin_fp16.h @@ -18,9 +18,11 @@ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. - * \file builtin_fp16.cc + * \file builtin_fp16.h * \brief Functions for conversion between fp32 and fp16, adopted from compiler-rt. */ +#ifndef COMPILER_RT_BUILTIN_FP16_H_ +#define COMPILER_RT_BUILTIN_FP16_H_ #include @@ -236,3 +238,5 @@ static inline DST_T __extendXfYf2__(SRC_T a) { dst_rep.i = result; return dst_rep.f; } + +#endif // COMPILER_RT_BUILTIN_FP16_H_ diff --git a/CMakeLists.txt b/CMakeLists.txt index 18f58c8..d7faa8a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,7 @@ tvm_option(USE_MICRO "Build with Micro" OFF) tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF) tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF) tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF) +tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF) # 3rdparty libraries tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include") @@ -353,6 +354,16 @@ else() set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG") endif(USE_RELAY_DEBUG) +if(USE_FALLBACK_STL_MAP) + message(STATUS "Building with STL Map...") + set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_FALLBACK_STL_MAP=1") + set_target_properties(tvm_topi PROPERTIES COMPILE_DEFINITIONS "USE_FALLBACK_STL_MAP=1") +else() + message(STATUS "Building with TVM Map...") + set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_FALLBACK_STL_MAP=0") + set_target_properties(tvm_topi PROPERTIES COMPILE_DEFINITIONS "USE_FALLBACK_STL_MAP=0") +endif(USE_FALLBACK_STL_MAP) + if(BUILD_FOR_HEXAGON) # Wrap pthread_create to allow setting custom stack size. set_target_properties(tvm_runtime PROPERTIES LINK_FLAGS diff --git a/apps/android_deploy/app/src/main/jni/Android.mk b/apps/android_deploy/app/src/main/jni/Android.mk index 3a83d8a..58f82f9 100644 --- a/apps/android_deploy/app/src/main/jni/Android.mk +++ b/apps/android_deploy/app/src/main/jni/Android.mk @@ -39,7 +39,6 @@ LOCAL_LDFLAGS := -L$(SYSROOT)/usr/lib/ -llog LOCAL_C_INCLUDES := $(ROOT_PATH)/include \ $(ROOT_PATH)/3rdparty/dlpack/include \ $(ROOT_PATH)/3rdparty/dmlc-core/include \ - $(ROOT_PATH)/3rdparty/HalideIR/src \ $(ROOT_PATH)/topi/include LOCAL_MODULE = tvm4j_runtime_packed diff --git a/apps/android_rpc/app/src/main/jni/Android.mk b/apps/android_rpc/app/src/main/jni/Android.mk index 3a83d8a..58f82f9 100644 --- a/apps/android_rpc/app/src/main/jni/Android.mk +++ b/apps/android_rpc/app/src/main/jni/Android.mk @@ -39,7 +39,6 @@ LOCAL_LDFLAGS := -L$(SYSROOT)/usr/lib/ -llog LOCAL_C_INCLUDES := $(ROOT_PATH)/include \ $(ROOT_PATH)/3rdparty/dlpack/include \ $(ROOT_PATH)/3rdparty/dmlc-core/include \ - $(ROOT_PATH)/3rdparty/HalideIR/src \ $(ROOT_PATH)/topi/include LOCAL_MODULE = tvm4j_runtime_packed diff --git a/apps/howto_deploy/Makefile b/apps/howto_deploy/Makefile index 4ee243c..a6d668f 100644 --- a/apps/howto_deploy/Makefile +++ b/apps/howto_deploy/Makefile @@ -22,7 +22,7 @@ DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ - -I${TVM_ROOT}/3rdparty/dlpack/include\ + -I${TVM_ROOT}/3rdparty/dlpack/include PKG_LDFLAGS = -L${TVM_ROOT}/build -ldl -pthread diff --git a/cmake/config.cmake b/cmake/config.cmake index 7e5734e..1b19692 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -212,6 +212,9 @@ set(USE_THRUST OFF) # Whether to build the TensorFlow TVMDSOOp module set(USE_TF_TVMDSOOP OFF) +# Whether to use STL's std::unordered_map or TVM's POD compatible Map +set(USE_FALLBACK_STL_MAP OFF) + # Whether to use hexagon device set(USE_HEXAGON_DEVICE OFF) set(USE_HEXAGON_SDK /path/to/sdk) diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index fa8c875..365eb60 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -23,17 +23,18 @@ #ifndef TVM_NODE_CONTAINER_H_ #define TVM_NODE_CONTAINER_H_ +#ifndef USE_FALLBACK_STL_MAP +#define USE_FALLBACK_STL_MAP 0 +#endif + #include #include #include #include -#include +#include #include -#include -#include #include -#include namespace tvm { @@ -43,6 +44,8 @@ using runtime::Downcast; using runtime::IterAdapter; using runtime::make_object; using runtime::Object; +using runtime::ObjectEqual; +using runtime::ObjectHash; using runtime::ObjectPtr; using runtime::ObjectPtrEqual; using runtime::ObjectPtrHash; @@ -50,44 +53,1191 @@ using runtime::ObjectRef; using runtime::String; using runtime::StringObj; -/*! \brief String-aware ObjectRef hash functor */ -struct ObjectHash { - size_t operator()(const ObjectRef& a) const { - if (const auto* str = a.as()) { - return String::HashBytes(str->data, str->size); - } - return ObjectPtrHash()(a); +#if (USE_FALLBACK_STL_MAP != 0) + +/*! \brief Shared content of all specializations of hash map */ +class MapNode : public Object { + public: + /*! \brief Type of the keys in the hash map */ + using key_type = ObjectRef; + /*! \brief Type of the values in the hash map */ + using mapped_type = ObjectRef; + /*! \brief Type of the actual underlying container */ + using ContainerType = std::unordered_map; + /*! \brief Iterator class */ + using iterator = ContainerType::iterator; + /*! \brief Iterator class */ + using const_iterator = ContainerType::const_iterator; + /*! \brief Type of value stored in the hash map */ + using KVType = ContainerType::value_type; + + static_assert(std::is_standard_layout::value, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; + static constexpr const char* _type_key = "Map"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); + + /*! + * \brief Number of elements in the SmallMapNode + * \return The result + */ + size_t size() const { return data_.size(); } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const { return data_.count(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { return data_.at(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { return data_.at(key); } + /*! \return begin iterator */ + iterator begin() { return data_.begin(); } + /*! \return const begin iterator */ + const_iterator begin() const { return data_.begin(); } + /*! \return end iterator */ + iterator end() { return data_.end(); } + /*! \return end iterator */ + const_iterator end() const { return data_.end(); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + const_iterator find(const key_type& key) const { return data_.find(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) { return data_.find(key); } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { data_.erase(position); } + /*! + * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type& key) { data_.erase(key); } + /*! + * \brief Create an empty container + * \return The object created + */ + static ObjectPtr Empty() { return make_object(); } + + protected: + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static ObjectPtr CreateFromRange(IterType first, IterType last) { + ObjectPtr p = make_object(); + p->data_ = ContainerType(first, last); + return p; + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + MapNode* map_node = static_cast(map->get()); + map_node->data_[kv.first] = kv.second; + } + /*! + * \brief Create an empty container with elements copying from another MapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(MapNode* from) { + ObjectPtr p = make_object(); + p->data_ = ContainerType(from->data_.begin(), from->data_.end()); + return p; } + /*! \brief The real container storing data */ + ContainerType data_; + template + friend class Map; }; -/*! \brief String-aware ObjectRef equal functor */ -struct ObjectEqual { - bool operator()(const ObjectRef& a, const ObjectRef& b) const { - if (a.same_as(b)) { - return true; +#else + +/*! \brief Shared content of all specializations of hash map */ +class MapNode : public Object { + public: + /*! \brief Type of the keys in the hash map */ + using key_type = ObjectRef; + /*! \brief Type of the values in the hash map */ + using mapped_type = ObjectRef; + /*! \brief Type of value stored in the hash map */ + using KVType = std::pair; + /*! \brief Iterator class */ + class iterator; + + static_assert(std::is_standard_layout::value, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; + static constexpr const char* _type_key = "Map"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); + + /*! + * \brief Number of elements in the SmallMapNode + * \return The result + */ + size_t size() const { return size_; } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key); + /*! \return begin iterator */ + iterator begin() const; + /*! \return end iterator */ + iterator end() const; + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const; + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position); + /*! + * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type& key) { erase(find(key)); } + + class iterator { + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = int64_t; + using value_type = KVType; + using pointer = KVType*; + using reference = KVType&; + /*! \brief Default constructor */ + iterator() : index(0), self(nullptr) {} + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { + return index == other.index && self == other.self; } - if (const auto* str_a = a.as()) { - if (const auto* str_b = b.as()) { - return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0; + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return !(*this == other); } + /*! \brief De-reference iterators */ + pointer operator->() const; + /*! \brief De-reference iterators */ + reference operator*() const { return *((*this).operator->()); } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++(); + /*! \brief Prefix self decrement, e.g. --iter */ + iterator& operator--(); + /*! \brief Suffix self increment */ + iterator operator++(int) { + iterator copy = *this; + ++(*this); + return copy; + } + /*! \brief Suffix self decrement */ + iterator operator--(int) { + iterator copy = *this; + --(*this); + return copy; + } + + protected: + /*! \brief Construct by value */ + iterator(uint64_t index, const MapNode* self) : index(index), self(self) {} + /*! \brief The position on the array */ + uint64_t index; + /*! \brief The container it points to */ + const MapNode* self; + + friend class DenseMapNode; + friend class SmallMapNode; + }; + /*! + * \brief Create an empty container + * \return The object created + */ + static inline ObjectPtr Empty(); + + protected: + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static inline ObjectPtr CreateFromRange(IterType first, IterType last); + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr* map); + /*! + * \brief Create an empty container with elements copying from another SmallMapNode + * \param from The source container + * \return The object created + */ + static inline ObjectPtr CopyFrom(MapNode* from); + /*! \brief number of slots minus 1 */ + uint64_t slots_; + /*! \brief number of entries in the container */ + uint64_t size_; + // Reference class + template + friend class Map; +}; + +/*! \brief A specialization of small-sized hash map */ +class SmallMapNode : public MapNode, + public runtime::InplaceArrayBase { + private: + static constexpr uint64_t kInitSize = 2; + static constexpr uint64_t kMaxSize = 4; + + public: + using MapNode::iterator; + using MapNode::KVType; + + /*! \brief Defaults to the destructor of InplaceArrayBase */ + ~SmallMapNode() = default; + /*! + * \brief Count the number of times a key exists in the SmallMapNode + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const { return find(key).index < size_; } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { + iterator itr = find(key); + CHECK(itr.index < size_) << "IndexError: key is not in Map"; + return itr->second; + } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { + iterator itr = find(key); + CHECK(itr.index < size_) << "IndexError: key is not in Map"; + return itr->second; + } + /*! \return begin iterator */ + iterator begin() const { return iterator(0, this); } + /*! \return end iterator */ + iterator end() const { return iterator(size_, this); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + KVType* ptr = static_cast(AddressOf(0)); + for (uint64_t i = 0; i < size_; ++i, ++ptr) { + if (ObjectEqual()(ptr->first, key)) { + return iterator(i, this); } } - return false; + return iterator(size_, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { Erase(position.index); } + + private: + /*! + * \brief Remove a position in SmallMapNode + * \param index The position to be removed + */ + void Erase(const uint64_t index) { + if (index >= size_) { + return; + } + KVType* begin = static_cast(AddressOf(0)); + KVType* last = begin + (size_ - 1); + if (index + 1 == size_) { + last->first.ObjectRef::~ObjectRef(); + last->second.ObjectRef::~ObjectRef(); + } else { + *(begin + index) = std::move(*last); + } + size_ -= 1; + } + /*! + * \brief Create an empty container + * \param n Number of empty slots + * \return The object created + */ + static ObjectPtr Empty(uint64_t n = kInitSize) { + using ::tvm::runtime::make_inplace_array_object; + ObjectPtr p = make_inplace_array_object(n); + p->size_ = 0; + p->slots_ = n; + return p; + } + /*! + * \brief Create an empty container initialized with a given range + * \param n Number of empty slots + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + * \return The object created + */ + template + static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { + ObjectPtr p = Empty(n); + KVType* ptr = static_cast(p->AddressOf(0)); + for (; first != last; ++first, ++p->size_) { + new (ptr++) KVType(*first); + } + return p; } + /*! + * \brief Create an empty container with elements copying from another SmallMapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(SmallMapNode* from) { + KVType* first = static_cast(from->AddressOf(0)); + KVType* last = first + from->size_; + return CreateFromRange(from->size_, first, last); + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + SmallMapNode* map_node = static_cast(map->get()); + iterator itr = map_node->find(kv.first); + if (itr.index < map_node->size_) { + itr->second = kv.second; + return; + } + if (map_node->size_ < map_node->slots_) { + KVType* ptr = static_cast(map_node->AddressOf(map_node->size_)); + new (ptr) KVType(kv); + ++map_node->size_; + return; + } + uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize)); + next_size = std::min(next_size, uint64_t(kMaxSize)); + CHECK_GT(next_size, map_node->slots_); + ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); + InsertMaybeReHash(kv, &new_map); + *map = std::move(new_map); + } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return static_cast(AddressOf(index)); } + /*! \brief A size function used by InplaceArrayBase */ + uint64_t GetSize() const { return size_; } + + protected: + friend class MapNode; + friend class DenseMapNode; + friend class runtime::InplaceArrayBase; }; -/*! \brief map node content */ -class MapNode : public Object { +/*! \brief A specialization of hash map that implements the idea of array-based hash map. + * Another reference implementation can be found [1]. + * + * A. Overview + * + * DenseMapNode did several improvements over traditional separate chaining hash, + * in terms of cache locality, memory footprints and data organization. + * + * A1. Implicit linked list. For better cache locality, instead of using linked list + * explicitly for each bucket, we store list data into a single array that spans contiguously + * in memory, and then carefully design access patterns to make sure most of them fall into + * a single cache line. + * + * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and + * traversal. This can be divided in 3 parts. + * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, + * which means the slot is empty but not allowed to be written. + * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is + * head of a linked list. + * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit + * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when + * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are + * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to + * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, + * then x must be one of the 126 pre-defined values. + * + * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. + * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. + * 16 key-value pairs. + * + * B. Implementation details + * + * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid + * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, + * we use the Fibonacci Hashing [2] trick. + * + * B2. Traverse a linked list in the array. + * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i + * indicates that it is list head, then we found the head; otherwise the list is empty. No probing + * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we + * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of + * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). + * + * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this + * element is in the linked list, and if not, we put it at the end by probing the next empty + * position in one of the 126 candidate positions. If the linked list does not even exist, but the + * slot for list head has been occupied by another linked list, we should find this intruder another + * place. + * + * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing + * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the + * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list + * head. + * + * [1] https://github.com/skarupke/flat_hash_map + * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ + * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + */ +class DenseMapNode : public MapNode { + private: + /*! \brief The number of elements in a memory block */ + static constexpr int kBlockCap = 16; + /*! \brief Maximum load factor of the hash map */ + static constexpr double kMaxLoadFactor = 0.99; + /*! \brief Binary representation of the metadata of an empty slot */ + static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); + /*! \brief Binary representation of the metadata of a protected slot */ + static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); + /*! \brief Number of probing choices available */ + static constexpr int kNumJumpDists = 126; + /*! \brief Head of the implicit linked list */ + struct ListNode; + /*! \brief POD type of a block of memory */ + struct Block { + uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)]; + }; + static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect"); + static_assert(std::is_standard_layout::value, "Block is not standard layout"); + public: - /*! \brief The corresponding conatiner type */ - using ContainerType = std::unordered_map; + using MapNode::iterator; - /*! \brief the data content */ - ContainerType data; + /*! + * \brief Destroy the DenseMapNode + */ + ~DenseMapNode() { this->Reset(); } + /*! \return The number of elements of the key */ + size_t count(const key_type& key) const { return !Search(key).IsNone(); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { return At(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { return At(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + ListNode node = Search(key); + return node.IsNone() ? end() : iterator(node.index, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { + uint64_t index = position.index; + if (position.self != nullptr && index <= this->slots_) { + Erase(ListNode(index, this)); + } + } + /*! \return begin iterator */ + iterator begin() const { + if (slots_ == 0) { + return iterator(0, this); + } + for (uint64_t index = 0; index <= slots_; ++index) { + if (!ListNode(index, this).IsEmpty()) { + return iterator(index, this); + } + } + return iterator(slots_ + 1, this); + } + /*! \return end iterator */ + iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); } - static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); + private: + /*! + * \brief Search for the given key + * \param key The key + * \return ListNode that associated with the key + */ + ListNode Search(const key_type& key) const { + if (this->size_ == 0) { + return ListNode(); + } + for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { + if (ObjectEqual()(key, iter.Key())) { + return iter; + } + } + return ListNode(); + } + /*! + * \brief Search for the given key, throw exception if not exists + * \param key The key + * \return ListNode that associated with the key + */ + mapped_type& At(const key_type& key) const { + ListNode iter = Search(key); + CHECK(!iter.IsNone()) << "IndexError: key is not in Map"; + return iter.Val(); + } + /*! + * \brief Try to insert a key, or do nothing if already exists + * \param key The indexing key + * \param result The linked-list entry found or just constructed + * \return A boolean, indicating if actual insertion happens + */ + bool TryInsert(const key_type& key, ListNode* result) { + if (slots_ == 0) { + return false; + } + // required that `iter` to be the head of a linked list through which we can iterator + ListNode iter = IndexFromHash(ObjectHash()(key)); + // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list + // Case 1: empty + if (iter.IsEmpty()) { + iter.NewHead(KVType(key, ObjectRef(nullptr))); + this->size_ += 1; + *result = iter; + return true; + } + // Case 2: body of an irrelevant list + if (!iter.IsHead()) { + // we move the elements around and construct the single-element linked list + return IsFull() ? false : TrySpareListHead(iter, key, result); + } + // Case 3: head of the relevant list + // we iterate through the linked list until the end + // make sure `iter` is the previous element of `next` + ListNode next = iter; + do { + // find equal item, do not insert + if (ObjectEqual()(key, next.Key())) { + *result = next; + return true; + } + // make sure `iter` is the previous element of `next` + iter = next; + } while (next.MoveToNext(this)); + // `iter` is the tail of the linked list + // always check capacity before insertion + if (IsFull()) { + return false; + } + // find the next empty slot + uint8_t jump; + if (!iter.GetNextEmpty(this, &jump, result)) { + return false; + } + result->NewTail(KVType(key, ObjectRef(nullptr))); + // link `iter` to `empty`, and move forward + iter.SetJump(jump); + this->size_ += 1; + return true; + } + /*! + * \brief Spare an entry to be the head of a linked list. + * As described in B3, during insertion, it is possible that the entire linked list does not + * exist, but the slot of its head has been occupied by other linked lists. In this case, we need + * to spare the slot by moving away the elements to another valid empty one to make insertion + * possible. + * \param target The given entry to be spared + * \param key The indexing key + * \param result The linked-list entry constructed as the head + * \return A boolean, if actual insertion happens + */ + bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { + // `target` is not the head of the linked list + // move the original item of `target` (if any) + // and construct new item on the position `target` + // To make `target` empty, we + // 1) find `w` the previous element of `target` in the linked list + // 2) copy the linked list starting from `r = target` + // 3) paste them after `w` + // read from the linked list after `r` + ListNode r = target; + // write to the tail of `w` + ListNode w = target.FindPrev(this); + // after `target` is moved, we disallow writing to the slot + bool is_first = true; + uint8_t r_meta, jump; + ListNode empty; + do { + // `jump` describes how `w` is jumped to `empty` + // rehash if there is no empty space after `w` + if (!w.GetNextEmpty(this, &jump, &empty)) { + return false; + } + // move `r` to `empty` + empty.NewTail(std::move(r.Data())); + // clear the metadata of `r` + r_meta = r.Meta(); + if (is_first) { + is_first = false; + r.SetProtected(); + } else { + r.SetEmpty(); + } + // link `w` to `empty`, and move forward + w.SetJump(jump); + w = empty; + // move `r` forward as well + } while (r.MoveToNext(this, r_meta)); + // finally we have done moving the linked list + // fill data_ into `target` + target.NewHead(KVType(key, ObjectRef(nullptr))); + this->size_ += 1; + *result = target; + return true; + } + /*! + * \brief Remove a ListNode + * \param iter The node to be removed + */ + void Erase(const ListNode& iter) { + this->size_ -= 1; + if (!iter.HasNext()) { + // `iter` is the last + if (!iter.IsHead()) { + // cut the link if there is any + iter.FindPrev(this).SetJump(0); + } + iter.Data().KVType::~KVType(); + iter.SetEmpty(); + } else { + ListNode last = iter, prev = iter; + for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { + } + iter.Data() = std::move(last.Data()); + last.SetEmpty(); + prev.SetJump(0); + } + } + /*! \brief Clear the container to empty, release all entries and memory acquired */ + void Reset() { + uint64_t n_blocks = CalcNumBlocks(this->slots_); + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr = data_[bi].bytes; + KVType* data_ptr = reinterpret_cast(data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t& meta = *meta_ptr; + if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { + meta = uint8_t(kEmptySlot); + data_ptr->KVType::~KVType(); + } + } + } + ReleaseMemory(); + } + /*! \brief Release the memory acquired by the container without deleting its entries stored inside + */ + void ReleaseMemory() { + delete[] data_; + data_ = nullptr; + slots_ = 0; + size_ = 0; + fib_shift_ = 63; + } + /*! + * \brief Create an empty container + * \param fib_shift The fib shift provided + * \param n_slots Number of slots required, should be power-of-two + * \return The object created + */ + static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { + CHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize)); + CHECK_EQ((n_slots & -n_slots), n_slots); + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(n_slots - 1); + Block* block = p->data_ = new Block[n_blocks]; + p->slots_ = n_slots - 1; + p->size_ = 0; + p->fib_shift_ = fib_shift; + for (uint64_t i = 0; i < n_blocks; ++i, ++block) { + std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); + } + return p; + } + /*! + * \brief Create an empty container with elements copying from another DenseMapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(DenseMapNode* from) { + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(from->slots_); + p->data_ = new Block[n_blocks]; + p->slots_ = from->slots_; + p->size_ = from->size_; + p->fib_shift_ = from->fib_shift_; + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr_from = from->data_[bi].bytes; + KVType* data_ptr_from = reinterpret_cast(from->data_[bi].bytes + kBlockCap); + uint8_t* meta_ptr_to = p->data_[bi].bytes; + KVType* data_ptr_to = reinterpret_cast(p->data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; + ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { + uint8_t& meta = *meta_ptr_to = *meta_ptr_from; + CHECK(meta != kProtectedSlot); + if (meta != uint8_t(kEmptySlot)) { + new (data_ptr_to) KVType(*data_ptr_from); + } + } + } + return p; + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + DenseMapNode* map_node = static_cast(map->get()); + ListNode iter; + // Try to insert. If succeed, we simply return + if (map_node->TryInsert(kv.first, &iter)) { + iter.Val() = kv.second; + return; + } + CHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize)); + // Otherwise, start rehash + ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2); + // Insert the given `kv` into the new hash map + InsertMaybeReHash(kv, &p); + uint64_t n_blocks = CalcNumBlocks(map_node->slots_); + // Then Insert data from the original block. + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr = map_node->data_[bi].bytes; + KVType* data_ptr = reinterpret_cast(map_node->data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t& meta = *meta_ptr; + if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { + meta = uint8_t(kEmptySlot); + KVType kv = std::move(*data_ptr); + InsertMaybeReHash(kv, &p); + } + } + } + map_node->ReleaseMemory(); + *map = p; + } + /*! + * \brief Check whether the hash table is full + * \return A boolean indicating whether hash table is full + */ + bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { + for (++index; index <= slots_; ++index) { + if (!ListNode(index, this).IsEmpty()) { + return index; + } + } + return slots_ + 1; + } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { + while (index != 0) { + index -= 1; + if (!ListNode(index, this).IsEmpty()) { + return index; + } + } + return slots_ + 1; + } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } + /*! \brief Construct from hash code */ + ListNode IndexFromHash(uint64_t hash_value) const { + return ListNode(FibHash(hash_value, fib_shift_), this); + } + /*! \brief Construct from hash code if the position is head of list */ + ListNode GetListHead(uint64_t hash_value) const { + ListNode node = IndexFromHash(hash_value); + return node.IsHead() ? node : ListNode(); + } + /*! \brief Construct the number of blocks in the hash table */ + static uint64_t CalcNumBlocks(uint64_t n_slots_m1) { + uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0; + return (n_slots + kBlockCap - 1) / kBlockCap; + } + /*! + * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. + * \param cap The lower-bound of the required capacity + * \param fib_shift The result shift for Fibonacci Hashing + * \param n_slots The result number of slots + */ + static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { + uint32_t shift = 64; + uint64_t slots = 1; + for (uint64_t c = cap; c; c >>= 1) { + shift -= 1; + slots <<= 1; + } + CHECK_GT(slots, cap); + if (slots < cap * 2) { + *fib_shift = shift - 1; + *n_slots = slots << 1; + } else { + *fib_shift = shift; + *n_slots = slots; + } + } + /*! + * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. + * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. + * \param hash_value The raw hash value + * \param fib_shift The shift in Fibonacci Hashing + * \return An index calculated using Fibonacci Hashing + */ + static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { + constexpr uint64_t coeff = 11400714819323198485ull; + return (coeff * hash_value) >> fib_shift; + } + /*! \brief The implicit in-place linked list used to index a chain */ + struct ListNode { + /*! \brief Construct None */ + ListNode() : index(0), block(nullptr) {} + /*! \brief Construct from position */ + ListNode(uint64_t index, const DenseMapNode* self) + : index(index), block(self->data_ + (index / kBlockCap)) {} + /*! \brief Metadata on the entry */ + uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } + /*! \brief Data on the entry */ + KVType& Data() const { + return *(reinterpret_cast(block->bytes + kBlockCap + + (index % kBlockCap) * sizeof(KVType))); + } + /*! \brief Key on the entry */ + key_type& Key() const { return Data().first; } + /*! \brief Value on the entry */ + mapped_type& Val() const { return Data().second; } + /*! \brief If the entry is head of linked list */ + bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } + /*! \brief If the entry is none */ + bool IsNone() const { return block == nullptr; } + /*! \brief If the entry is empty slot */ + bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } + /*! \brief If the entry is protected slot */ + bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } + /*! \brief Set the entry to be empty */ + void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } + /*! \brief Set the entry to be protected */ + void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } + /*! \brief Set the entry's jump to its next entry */ + void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } + /*! \brief Construct a head of linked list in-place */ + void NewHead(KVType v) const { + Meta() = 0b00000000; + new (&Data()) KVType(std::move(v)); + } + /*! \brief Construct a tail of linked list in-place */ + void NewTail(KVType v) const { + Meta() = 0b10000000; + new (&Data()) KVType(std::move(v)); + } + /*! \brief If the entry has next entry on the linked list */ + bool HasNext() const { return kNextProbeLocation[Meta() & 0b01111111] != 0; } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapNode* self, uint8_t meta) { + uint64_t offset = kNextProbeLocation[meta & 0b01111111]; + if (offset == 0) { + index = 0; + block = nullptr; + return false; + } + index = (index + offset) & (self->slots_); + block = self->data_ + (index / kBlockCap); + return true; + } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); } + /*! \brief Get the previous entry on the linked list */ + ListNode FindPrev(const DenseMapNode* self) const { + // start from the head of the linked list, which must exist + ListNode next = self->IndexFromHash(ObjectHash()(Key())); + // `prev` is always the previous item of `next` + ListNode prev = next; + for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { + } + return prev; + } + /*! \brief Get the next empty jump */ + bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const { + for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { + ListNode candidate((index + kNextProbeLocation[idx]) & (self->slots_), self); + if (candidate.IsEmpty()) { + *jump = idx; + *result = candidate; + return true; + } + } + return false; + } + /*! \brief Index on the real array */ + uint64_t index; + /*! \brief Pointer to the actual block */ + Block* block; + }; + + protected: + /*! \brief fib shift in Fibonacci Hashing */ + uint32_t fib_shift_; + /*! \brief array of data blocks */ + Block* data_; + /* clang-format off */ + /*! \brief Candidates of probing distance */ + TVM_DLL static constexpr uint64_t kNextProbeLocation[kNumJumpDists] { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + // Quadratic probing with triangle numbers. See also: + // 1) https://en.wikipedia.org/wiki/Quadratic_probing + // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + // 3) https://github.com/skarupke/flat_hash_map + 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, + 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, + 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, + 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, + 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, + 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, + 2211, 2278, 2346, 2415, 2485, 2556, 2628, + // larger triangle numbers + 8515, 19110, 42778, 96141, 216153, + 486591, 1092981, 2458653, 5532801, 12442566, + 27993903, 62983476, 141717030, 318844378, 717352503, + 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, + 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, + 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, + 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, + 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626, + 1029107982097042876, 2315492959180353330, 5209859154120846435, + }; + /* clang-format on */ + friend class MapNode; }; +#define _TVM_DISPATCH_MAP(base, var, body) \ + { \ + using TSmall = SmallMapNode*; \ + using TDense = DenseMapNode*; \ + uint64_t slots = base->slots_; \ + if (slots <= SmallMapNode::kMaxSize) { \ + TSmall var = static_cast(base); \ + body; \ + } else { \ + TDense var = static_cast(base); \ + body; \ + } \ + } + +#define _TVM_DISPATCH_MAP_CONST(base, var, body) \ + { \ + using TSmall = const SmallMapNode*; \ + using TDense = const DenseMapNode*; \ + uint64_t slots = base->slots_; \ + if (slots <= SmallMapNode::kMaxSize) { \ + TSmall var = static_cast(base); \ + body; \ + } else { \ + TDense var = static_cast(base); \ + body; \ + } \ + } + +inline MapNode::iterator::pointer MapNode::iterator::operator->() const { + _TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); +} + +inline MapNode::iterator& MapNode::iterator::operator++() { + _TVM_DISPATCH_MAP_CONST(self, p, { + index = p->IncItr(index); + return *this; + }); +} + +inline MapNode::iterator& MapNode::iterator::operator--() { + _TVM_DISPATCH_MAP_CONST(self, p, { + index = p->IncItr(index); + return *this; + }); +} + +inline size_t MapNode::count(const key_type& key) const { + _TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); +} + +inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const { + _TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); +} + +inline MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) { + _TVM_DISPATCH_MAP(this, p, { return p->at(key); }); +} + +inline MapNode::iterator MapNode::begin() const { + _TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); +} + +inline MapNode::iterator MapNode::end() const { + _TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); }); +} + +inline MapNode::iterator MapNode::find(const MapNode::key_type& key) const { + _TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); +} + +inline void MapNode::erase(const MapNode::iterator& position) { + _TVM_DISPATCH_MAP(this, p, { return p->erase(position); }); +} + +#undef _TVM_DISPATCH_MAP +#undef _TVM_DISPATCH_MAP_CONST + +inline ObjectPtr MapNode::Empty() { return SmallMapNode::Empty(); } + +inline ObjectPtr MapNode::CopyFrom(MapNode* from) { + if (from->slots_ <= SmallMapNode::kMaxSize) { + return SmallMapNode::CopyFrom(static_cast(from)); + } else { + return DenseMapNode::CopyFrom(static_cast(from)); + } +} + +template +inline ObjectPtr MapNode::CreateFromRange(IterType first, IterType last) { + int64_t _cap = std::distance(first, last); + if (_cap < 0) { + return SmallMapNode::Empty(); + } + uint64_t cap = static_cast(_cap); + if (cap < SmallMapNode::kMaxSize) { + return SmallMapNode::CreateFromRange(cap, first, last); + } + uint32_t fib_shift; + uint64_t n_slots; + DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots); + ObjectPtr obj = DenseMapNode::Empty(fib_shift, n_slots); + for (; first != last; ++first) { + KVType kv(*first); + DenseMapNode::InsertMaybeReHash(kv, &obj); + } + return obj; +} + +inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize; + MapNode* base = static_cast(map->get()); + if (base->slots_ < kSmallMapMaxSize) { + SmallMapNode::InsertMaybeReHash(kv, map); + } else if (base->slots_ == kSmallMapMaxSize) { + if (base->size_ < base->slots_) { + SmallMapNode::InsertMaybeReHash(kv, map); + } else { + ObjectPtr new_map = MapNode::CreateFromRange(base->begin(), base->end()); + DenseMapNode::InsertMaybeReHash(kv, &new_map); + *map = std::move(new_map); + } + } else { + DenseMapNode::InsertMaybeReHash(kv, map); + } +} + +namespace runtime { +template <> +inline ObjectPtr make_object<>() = delete; +} // namespace runtime + +#endif + /*! * \brief Map container of NodeRef->NodeRef in DSL graph. * Map implements copy on write semantics, which means map is mutable @@ -102,22 +1252,40 @@ template ::value>::type> class Map : public ObjectRef { public: + using key_type = K; + using mapped_type = V; + class iterator; /*! * \brief default constructor */ - Map() { data_ = make_object(); } + Map() { data_ = MapNode::Empty(); } /*! * \brief move constructor * \param other source */ - Map(Map&& other) { // NOLINT(*) - data_ = std::move(other.data_); - } + Map(Map&& other) { data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Map(const Map& other) : ObjectRef(other.data_) { // NOLINT(*) + Map(const Map& other) : ObjectRef(other.data_) {} + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Map& operator=(Map&& other) { + data_ = std::move(other.data_); + return *this; + } + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Map& operator=(const Map& other) { + data_ = other.data_; + return *this; } /*! * \brief constructor from pointer @@ -132,14 +1300,14 @@ class Map : public ObjectRef { */ template Map(IterType begin, IterType end) { - assign(begin, end); + data_ = MapNode::CreateFromRange(begin, end); } /*! * \brief constructor from initializer list * \param init The initalizer list */ - Map(std::initializer_list > init) { // NOLINT(*) - assign(init.begin(), init.end()); + Map(std::initializer_list> init) { + data_ = MapNode::CreateFromRange(init.begin(), init.end()); } /*! * \brief constructor from unordered_map @@ -147,66 +1315,50 @@ class Map : public ObjectRef { */ template Map(const std::unordered_map& init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - /*! - * \brief reset the array to content from iterator. - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - void assign(IterType begin, IterType end) { - ObjectPtr n = make_object(); - for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first, i->second)); - } - data_ = std::move(n); + data_ = MapNode::CreateFromRange(init.begin(), init.end()); } /*! * \brief Read element from map. * \param key The key * \return the corresonding element. */ - inline const V operator[](const K& key) const { - return DowncastNoCheck(static_cast(data_.get())->data.at(key)); - } + const V at(const K& key) const { return DowncastNoCheck(GetMapNode()->at(key)); } /*! * \brief Read element from map. * \param key The key * \return the corresonding element. */ - inline const V at(const K& key) const { - return DowncastNoCheck(static_cast(data_.get())->data.at(key)); - } + const V operator[](const K& key) const { return this->at(key); } /*! \return The size of the array */ - inline size_t size() const { - if (data_.get() == nullptr) return 0; - return static_cast(data_.get())->data.size(); + size_t size() const { + MapNode* n = GetMapNode(); + return n == nullptr ? 0 : n->size(); } /*! \return The number of elements of the key */ - inline size_t count(const K& key) const { - if (data_.get() == nullptr) return 0; - return static_cast(data_.get())->data.count(key); + size_t count(const K& key) const { + MapNode* n = GetMapNode(); + return n == nullptr ? 0 : GetMapNode()->count(key); } + /*! \return whether array is empty */ + bool empty() const { return size() == 0; } + /*! + * \brief set the Map. + * \param key The index key. + * \param value The value to be setted. + */ + void Set(const K& key, const V& value) { + CopyOnWrite(); + MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_); + } + /*! \return begin iterator */ + iterator begin() const { return iterator(GetMapNode()->begin()); } + /*! \return end iterator */ + iterator end() const { return iterator(GetMapNode()->end()); } + /*! \return find the key and returns the associated iterator */ + iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); } + + void erase(const K& key) { CopyOnWrite()->erase(key); } + /*! * \brief copy on write semantics * Do nothing if current handle is the unique copy of the array. @@ -215,50 +1367,64 @@ class Map : public ObjectRef { * * \return Handle to the internal node container(which ganrantees to be unique) */ - inline MapNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { - ObjectPtr n = make_object(); - n->data = static_cast(data_.get())->data; - ObjectPtr(std::move(n)).swap(data_); + MapNode* CopyOnWrite() { + if (data_.get() == nullptr) { + data_ = MapNode::Empty(); + } else if (!data_.unique()) { + data_ = MapNode::CopyFrom(GetMapNode()); } - return static_cast(data_.get()); - } - /*! - * \brief set the Map. - * \param key The index key. - * \param value The value to be setted. - */ - inline void Set(const K& key, const V& value) { - MapNode* n = this->CopyOnWrite(); - n->data[key] = value; + return GetMapNode(); } - - /*! \return whether array is empty */ - inline bool empty() const { return size() == 0; } /*! \brief specify container node */ using ContainerType = MapNode; - struct ValueConverter { - using ResultType = std::pair; - static inline ResultType convert(const std::pair& n) { - return std::make_pair(DowncastNoCheck(n.first), DowncastNoCheck(n.second)); + /*! \brief Iterator of the hash map */ + class iterator { + public: + using iterator_category = std::bidirectional_iterator_tag; + using difference_type = int64_t; + using value_type = const std::pair; + using pointer = value_type*; + using reference = value_type; + + iterator() : itr() {} + + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { return itr == other.itr; } + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return itr != other.itr; } + /*! \brief De-reference iterators is not allowed */ + pointer operator->() const = delete; + /*! \brief De-reference iterators */ + reference operator*() const { + auto& kv = *itr; + return std::make_pair(DowncastNoCheck(kv.first), DowncastNoCheck(kv.second)); + } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++() { + ++itr; + return *this; + } + /*! \brief Suffix self increment */ + iterator operator++(int) { + iterator copy = *this; + ++(*this); + return copy; } - }; - using iterator = IterAdapter; + private: + iterator(const MapNode::iterator& itr) // NOLINT(*) + : itr(itr) {} - /*! \return begin iterator */ - inline iterator begin() const { - return iterator(static_cast(data_.get())->data.begin()); - } - /*! \return end iterator */ - inline iterator end() const { - return iterator(static_cast(data_.get())->data.end()); - } - /*! \return begin iterator */ - inline iterator find(const K& key) const { - return iterator(static_cast(data_.get())->data.find(key)); - } + template + friend class Map; + + MapNode::iterator itr; + }; + + private: + /*! \brief Return data_ as type of pointer of MapNode */ + MapNode* GetMapNode() const { return static_cast(data_.get()); } }; } // namespace tvm @@ -267,7 +1433,7 @@ namespace tvm { namespace runtime { // Additional overloads for PackedFunc checking. template -struct ObjectTypeChecker > { +struct ObjectTypeChecker> { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; if (!ptr->IsInstance()) return false; @@ -283,12 +1449,12 @@ struct ObjectTypeChecker > { }; template -struct ObjectTypeChecker > { +struct ObjectTypeChecker> { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; if (!ptr->IsInstance()) return false; const MapNode* n = static_cast(ptr); - for (const auto& kv : n->data) { + for (const auto& kv : *n) { if (!ObjectTypeChecker::Check(kv.first.get())) return false; if (!ObjectTypeChecker::Check(kv.second.get())) return false; } diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index fe23ace..5467ae4 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -34,6 +34,7 @@ #include #include #include +#include // We use c++14 std::experimental::string_view for optimizing hash computation // only right now, its usage is limited in this file. Any broader usage of // std::experiment in our core codebase is discouraged and needs community @@ -71,10 +72,28 @@ class StringRef; } // namespace llvm namespace tvm { +namespace runtime { -struct ObjectEqual; +/*! \brief String-aware ObjectRef equal functor */ +struct ObjectHash { + /*! + * \brief Calculate the hash code of an ObjectRef + * \param a The given ObjectRef + * \return Hash code of a, string hash for strings and pointer address otherwise. + */ + size_t operator()(const ObjectRef& a) const; +}; -namespace runtime { +/*! \brief String-aware ObjectRef hash functor */ +struct ObjectEqual { + /*! + * \brief Check if the two ObjectRef are equal + * \param a One ObjectRef + * \param b The other ObjectRef + * \return String equality if both are strings, pointer address equality otherwise. + */ + bool operator()(const ObjectRef& a, const ObjectRef& b) const; +}; /*! * \brief Base template for classes with array like memory layout. @@ -209,7 +228,7 @@ class IterAdapter { using difference_type = typename std::iterator_traits::difference_type; using value_type = typename Converter::ResultType; using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; // NOLINT(*) + using reference = typename Converter::ResultType&; using iterator_category = typename std::iterator_traits::iterator_category; explicit IterAdapter(TIter iter) : iter_(iter) {} @@ -221,12 +240,12 @@ class IterAdapter { --iter_; return *this; } - IterAdapter& operator++(int) { + IterAdapter operator++(int) { IterAdapter copy = *this; ++iter_; return copy; } - IterAdapter& operator--(int) { + IterAdapter operator--(int) { IterAdapter copy = *this; --iter_; return copy; @@ -536,6 +555,7 @@ template ::value>::type> class Array : public ObjectRef { public: + using value_type = T; // constructors /*! * \brief default constructor @@ -1329,7 +1349,7 @@ class String : public ObjectRef { friend String operator+(const String& lhs, const char* rhs); friend String operator+(const char* lhs, const String& rhs); - friend struct tvm::ObjectEqual; + friend struct tvm::runtime::ObjectEqual; }; /*! \brief An object representing string moved from std::string. */ @@ -1484,6 +1504,25 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, s } } +inline size_t ObjectHash::operator()(const ObjectRef& a) const { + if (const auto* str = a.as()) { + return String::HashBytes(str->data, str->size); + } + return ObjectPtrHash()(a); +} + +inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) const { + if (a.same_as(b)) { + return true; + } + if (const auto* str_a = a.as()) { + if (const auto* str_b = b.as()) { + return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0; + } + } + return false; +} + template <> struct PackedFuncValueConverter<::tvm::runtime::String> { static String From(const TVMArgValue& val) { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 483ad6b..851c4ad 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -66,6 +66,8 @@ struct TypeIndex { kRuntimeString = 3, /*! \brief runtime::Array. */ kRuntimeArray = 4, + /*! \brief runtime::Map. */ + kRuntimeMap = 5, // static assignments that may subject to change. kRuntimeClosure, kRuntimeADT, @@ -590,6 +592,7 @@ class ObjectRef { friend struct ObjectPtrHash; friend class TVMRetValue; friend class TVMArgsSetter; + friend class ObjectInternal; template friend SubRef Downcast(BaseRef ref); }; diff --git a/src/ir/expr.cc b/src/ir/expr.cc index b032fe5..fd380aa 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -171,8 +171,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '{'; - for (auto it = op->data.begin(); it != op->data.end(); ++it) { - if (it != op->data.begin()) { + for (auto it = op->begin(); it != op->end(); ++it) { + if (it != op->begin()) { p->stream << ", "; } if (it->first->IsInstance()) { diff --git a/src/ir/module.cc b/src/ir/module.cc index 0d6eeb1..25ecab2 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -278,9 +278,9 @@ void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) void IRModuleNode::Remove(const GlobalVar& var) { auto functions_node = this->functions.CopyOnWrite(); - functions_node->data.erase(var); + functions_node->erase(var); auto gvar_node = global_var_map_.CopyOnWrite(); - gvar_node->data.erase(var->name_hint); + gvar_node->erase(var->name_hint); } BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { diff --git a/src/node/container.cc b/src/node/container.cc index bdebb7f..60b5f40 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -198,7 +198,7 @@ struct MapNodeTrait { // parameters. using KV = std::pair; std::vector temp; - for (const auto& kv : key->data) { + for (const auto& kv : *key) { size_t hashed_value; if (hash_reduce->LookupHashedValue(kv.first, &hashed_value)) { temp.emplace_back(hashed_value, kv.second); @@ -208,7 +208,7 @@ struct MapNodeTrait { std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); // add size to the hash - hash_reduce(static_cast(key->data.size())); + hash_reduce(static_cast(key->size())); // hash the content for (size_t i = 0; i < temp.size();) { size_t k = i + 1; @@ -230,7 +230,7 @@ struct MapNodeTrait { // parameters. using KV = std::pair; std::vector temp; - for (const auto& kv : key->data) { + for (const auto& kv : *key) { temp.push_back(std::make_pair(Downcast(kv.first), kv.second)); } // sort by the hash key of the keys. @@ -238,7 +238,7 @@ struct MapNodeTrait { [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); // NOTE: we won't have ties // add size to the hash after sorting. - hash_reduce(static_cast(key->data.size())); + hash_reduce(static_cast(key->size())); // hash the content for (size_t i = 0; i < temp.size(); ++i) { hash_reduce(temp[i].first); @@ -247,7 +247,7 @@ struct MapNodeTrait { } static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) { - bool is_str_map = std::all_of(key->data.begin(), key->data.end(), [](const auto& v) { + bool is_str_map = std::all_of(key->begin(), key->end(), [](const auto& v) { return v.first->template IsInstance(); }); if (is_str_map) { @@ -258,35 +258,35 @@ struct MapNodeTrait { } static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { - for (const auto& kv : lhs->data) { + for (const auto& kv : *lhs) { // Only allow equal checking if the keys are already mapped // This resolves common use cases where we want to store // Map where Var is defined in the function // parameters. ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); if (!rhs_key.defined()) return false; - auto it = rhs->data.find(rhs_key); - if (it == rhs->data.end()) return false; + auto it = rhs->find(rhs_key); + if (it == rhs->end()) return false; if (!equal(kv.second, it->second)) return false; } return true; } static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { - for (const auto& kv : lhs->data) { - auto it = rhs->data.find(kv.first); - if (it == rhs->data.end()) return false; + for (const auto& kv : *lhs) { + auto it = rhs->find(kv.first); + if (it == rhs->end()) return false; if (!equal(kv.second, it->second)) return false; } return true; } static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { - if (rhs->data.size() != lhs->data.size()) return false; - if (rhs->data.size() == 0) return true; - bool ls = std::all_of(lhs->data.begin(), lhs->data.end(), + if (rhs->size() != lhs->size()) return false; + if (rhs->size() == 0) return true; + bool ls = std::all_of(lhs->begin(), lhs->end(), [](const auto& v) { return v.first->template IsInstance(); }); - bool rs = std::all_of(rhs->data.begin(), rhs->data.end(), + bool rs = std::all_of(rhs->begin(), rhs->end(), [](const auto& v) { return v.first->template IsInstance(); }); if (ls != rs) { return false; @@ -297,22 +297,18 @@ struct MapNodeTrait { TVM_REGISTER_OBJECT_TYPE(MapNode); TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) - .set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); + .set_creator([](const std::string&) -> ObjectPtr { return MapNode::Empty(); }); TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args.size() % 2, 0); - MapNode::ContainerType data; + std::unordered_map data; for (int i = 0; i < args.num_args; i += 2) { ObjectRef k = String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef(); ObjectRef v = args[i + 1]; data.emplace(std::move(k), std::move(v)); } - auto node = make_object(); - node->data = std::move(data); - *ret = Map(node); + *ret = Map(std::move(data)); }); TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -320,7 +316,7 @@ TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) Object* ptr = static_cast(args[0].value().v_handle); CHECK(ptr->IsInstance()); auto* n = static_cast(ptr); - *ret = static_cast(n->data.size()); + *ret = static_cast(n->size()); }); TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -329,9 +325,9 @@ TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* re CHECK(ptr->IsInstance()); auto* n = static_cast(ptr); - auto it = n->data.find(String::CanConvertFrom(args[1]) ? args[1].operator String() - : args[1].operator ObjectRef()); - CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; + auto it = n->find(String::CanConvertFrom(args[1]) ? args[1].operator String() + : args[1].operator ObjectRef()); + CHECK(it != n->end()) << "cannot find the corresponding key in the Map"; *ret = (*it).second; }); @@ -340,8 +336,8 @@ TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) Object* ptr = static_cast(args[0].value().v_handle); CHECK(ptr->IsInstance()); const MapNode* n = static_cast(ptr); - int64_t cnt = n->data.count(String::CanConvertFrom(args[1]) ? args[1].operator String() - : args[1].operator ObjectRef()); + int64_t cnt = n->count(String::CanConvertFrom(args[1]) ? args[1].operator String() + : args[1].operator ObjectRef()); *ret = cnt; }); @@ -350,7 +346,7 @@ TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) Object* ptr = static_cast(args[0].value().v_handle); auto* n = static_cast(ptr); Array rkvs; - for (const auto& kv : n->data) { + for (const auto& kv : *n) { if (kv.first->IsInstance()) { rkvs.push_back(Downcast(kv.first)); } else { @@ -360,4 +356,8 @@ TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) } *ret = std::move(rkvs); }); + +#if (USE_FALLBACK_STL_MAP == 0) +TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[]; +#endif } // namespace tvm diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 8de21da..ec82c91 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -234,7 +234,7 @@ ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, runtime::TVMArgsSetter setter(values.data(), tcodes.data()); int index = 0; - for (auto& kv : static_cast(kwargs.get())->data) { + for (const auto& kv : *static_cast(kwargs.get())) { setter(index, Downcast(kv.first).c_str()); setter(index + 1, kv.second); index += 2; diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 3866533..4382579 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -35,6 +35,7 @@ #include #include +#include "../runtime/object_internal.h" #include "../support/base64.h" namespace tvm { @@ -110,15 +111,15 @@ class NodeIndexer : public AttrVisitor { } } else if (node->IsInstance()) { MapNode* n = static_cast(node); - bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) { + bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) { return v.first->template IsInstance(); }); if (is_str_map) { - for (const auto& kv : n->data) { + for (const auto& kv : *n) { MakeIndex(const_cast(kv.second.get())); } } else { - for (const auto& kv : n->data) { + for (const auto& kv : *n) { MakeIndex(const_cast(kv.first.get())); MakeIndex(const_cast(kv.second.get())); } @@ -147,6 +148,11 @@ struct JSONNode { std::vector keys; /*! \brief values of a map or array. */ std::vector data; + /*! + * \brief field member dependency. + * NOTE: This is an auxiliary data structure for loading, and it won't be serialized to json. + */ + std::vector fields; void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); @@ -253,16 +259,16 @@ class JSONAttrGetter : public AttrVisitor { } } else if (node->IsInstance()) { MapNode* n = static_cast(node); - bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) { + bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) { return v.first->template IsInstance(); }); if (is_str_map) { - for (const auto& kv : n->data) { + for (const auto& kv : *n) { node_->keys.push_back(Downcast(kv.first)); node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); } } else { - for (const auto& kv : n->data) { + for (const auto& kv : *n) { node_->data.push_back(node_index_->at(const_cast(kv.first.get()))); node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); } @@ -274,19 +280,71 @@ class JSONAttrGetter : public AttrVisitor { } }; +class FieldDependencyFinder : public AttrVisitor { + public: + JSONNode* jnode_; + ReflectionVTable* reflection_ = ReflectionVTable::Global(); + + std::string GetValue(const char* key) const { + auto it = jnode_->attrs.find(key); + if (it == jnode_->attrs.end()) { + LOG(FATAL) << "JSONReader: cannot find field " << key; + } + return it->second; + } + template + void ParseValue(const char* key, T* value) const { + std::istringstream is(GetValue(key)); + is >> *value; + if (is.fail()) { + LOG(FATAL) << "Wrong value format for field " << key; + } + } + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, void** value) final {} + void Visit(const char* key, DataType* value) final {} + void Visit(const char* key, runtime::NDArray* value) final {} + void Visit(const char* key, ObjectRef* value) final { + size_t index; + ParseValue(key, &index); + jnode_->fields.push_back(index); + } + void Find(Object* node, JSONNode* jnode) { + // Skip None + if (node == nullptr) { + return; + } + // Skip the objects that have their own string repr + if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(node, nullptr)) { + return; + } + // Skip containers + if (jnode->type_key == ArrayNode::_type_key || jnode->type_key == MapNode::_type_key) { + return; + } + jnode_ = jnode; + reflection_->VisitAttrs(node, this); + } +}; + // Helper class to set the attributes of a node // from given json node. class JSONAttrSetter : public AttrVisitor { public: const std::vector>* node_list_; const std::vector* tensor_list_; - JSONNode* node_; + JSONNode* jnode_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { - auto it = node_->attrs.find(key); - if (it == node_->attrs.end()) { + auto it = jnode_->attrs.find(key); + if (it == jnode_->attrs.end()) { LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; @@ -325,33 +383,46 @@ class JSONAttrSetter : public AttrVisitor { *value = ObjectRef(node_list_->at(index)); } // set node to be current JSONNode - void Set(Object* node) { - if (node == nullptr) return; - - if (node->IsInstance()) { - ArrayNode* n = static_cast(node); - CHECK_EQ(n->size(), node_->data.size()); - int64_t i = 0; - for (size_t index : node_->data) { - n->SetItem(i++, ObjectRef(node_list_->at(index))); + void Set(ObjectPtr* node, JSONNode* jnode) { + // Skip None + if (node->get() == nullptr) { + return; + } + // Skip the objects that have their own string repr + if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(node->get(), nullptr)) { + return; + } + // handling Array + if (jnode->type_key == ArrayNode::_type_key) { + std::vector container; + for (auto index : jnode->data) { + container.push_back(ObjectRef(node_list_->at(index))); } - } else if (node->IsInstance()) { - MapNode* n = static_cast(node); - if (node_->keys.empty()) { - CHECK_EQ(node_->data.size() % 2, 0U); - for (size_t i = 0; i < node_->data.size(); i += 2) { - n->data[ObjectRef(node_list_->at(node_->data[i]))] = - ObjectRef(node_list_->at(node_->data[i + 1])); + Array array(container); + *node = runtime::ObjectInternal::MoveObjectPtr(&array); + return; + } + // handling Map + if (jnode->type_key == MapNode::_type_key) { + std::unordered_map container; + if (jnode->keys.empty()) { + CHECK_EQ(jnode->data.size() % 2, 0U); + for (size_t i = 0; i < jnode->data.size(); i += 2) { + container[ObjectRef(node_list_->at(jnode->data[i]))] = + ObjectRef(node_list_->at(jnode->data[i + 1])); } } else { - CHECK_EQ(node_->data.size(), node_->keys.size()); - for (size_t i = 0; i < node_->data.size(); ++i) { - n->data[String(node_->keys[i])] = ObjectRef(node_list_->at(node_->data[i])); + CHECK_EQ(jnode->data.size(), jnode->keys.size()); + for (size_t i = 0; i < jnode->data.size(); ++i) { + container[String(jnode->keys[i])] = ObjectRef(node_list_->at(jnode->data[i])); } } - } else { - reflection_->VisitAttrs(node, this); + Map map(container); + *node = runtime::ObjectInternal::MoveObjectPtr(&map); + return; } + jnode_ = jnode; + reflection_->VisitAttrs(node->get(), this); } }; @@ -413,6 +484,41 @@ struct JSONGraph { } return g; } + + std::vector TopoSort() const { + size_t n_nodes = nodes.size(); + std::vector topo_order; + std::vector in_degree(n_nodes, 0); + for (const JSONNode& jnode : nodes) { + for (size_t i : jnode.data) { + ++in_degree[i]; + } + for (size_t i : jnode.fields) { + ++in_degree[i]; + } + } + for (size_t i = 0; i < n_nodes; ++i) { + if (in_degree[i] == 0) { + topo_order.push_back(i); + } + } + for (size_t p = 0; p < topo_order.size(); ++p) { + const JSONNode& jnode = nodes[topo_order[p]]; + for (size_t i : jnode.data) { + if (--in_degree[i] == 0) { + topo_order.push_back(i); + } + } + for (size_t i : jnode.fields) { + if (--in_degree[i] == 0) { + topo_order.push_back(i); + } + } + } + CHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file"; + std::reverse(std::begin(topo_order), std::end(topo_order)); + return topo_order; + } }; std::string SaveJSON(const ObjectRef& n) { @@ -424,14 +530,17 @@ std::string SaveJSON(const ObjectRef& n) { } ObjectRef LoadJSON(std::string json_str) { + ReflectionVTable* reflection = ReflectionVTable::Global(); JSONGraph jgraph; - std::vector> nodes; - std::vector tensors; { // load in json graph. std::istringstream is(json_str); dmlc::JSONReader reader(&is); jgraph.Load(&reader); + } + size_t n_nodes = jgraph.nodes.size(); + std::vector tensors; + { // load in tensors for (const std::string& blob : jgraph.b64ndarrays) { dmlc::MemoryStringStream mstrm(const_cast(&blob)); @@ -439,38 +548,33 @@ ObjectRef LoadJSON(std::string json_str) { b64strm.InitPosition(); runtime::NDArray temp; CHECK(temp.Load(&b64strm)); - tensors.emplace_back(temp); + tensors.emplace_back(std::move(temp)); } } - ReflectionVTable* reflection = ReflectionVTable::Global(); - - // node 0 is always null - nodes.reserve(jgraph.nodes.size()); - - for (const JSONNode& jnode : jgraph.nodes) { - if (jnode.type_key == ArrayNode::_type_key) { - CHECK(jnode.repr_bytes.empty()); - nodes.emplace_back(ArrayNode::CreateRepeated(jnode.data.size(), ObjectRef(nullptr))); - } else if (jnode.type_key.length() != 0) { - ObjectPtr node = reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); - nodes.emplace_back(std::move(node)); - } else { - nodes.emplace_back(ObjectPtr()); + // Pass 1: create all non-container objects + std::vector> nodes(n_nodes, nullptr); + for (size_t i = 0; i < n_nodes; ++i) { + const JSONNode& jnode = jgraph.nodes[i]; + if (jnode.type_key.length() != 0) { + nodes[i] = reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); } } - CHECK_EQ(nodes.size(), jgraph.nodes.size()); - JSONAttrSetter setter; - setter.node_list_ = &nodes; - setter.tensor_list_ = &tensors; - - for (size_t i = 0; i < nodes.size(); ++i) { - setter.node_ = &jgraph.nodes[i]; - // Skip the nodes that has an repr bytes representation. - // NOTE: the second condition is used to guard the case - // where the repr bytes itself is an empty string "". - if (setter.node_->repr_bytes.length() == 0 && nodes[i] != nullptr && - !reflection->GetReprBytes(nodes[i].get(), nullptr)) { - setter.Set(nodes[i].get()); + // Pass 2: figure out all field dependency + { + FieldDependencyFinder dep_finder; + for (size_t i = 0; i < n_nodes; ++i) { + dep_finder.Find(nodes[i].get(), &jgraph.nodes[i]); + } + } + // Pass 3: topo sort + std::vector topo_order = jgraph.TopoSort(); + // Pass 4: set all values + { + JSONAttrSetter setter; + setter.node_list_ = &nodes; + setter.tensor_list_ = &tensors; + for (size_t i : topo_order) { + setter.Set(&nodes[i], &jgraph.nodes[i]); } } return ObjectRef(nodes.at(jgraph.root)); diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index f255b28..beae137 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -28,6 +28,7 @@ #include #include +#include namespace tvm { namespace runtime { @@ -83,6 +84,15 @@ class ObjectInternal { // address translation return static_cast(static_cast(handle)); } + /*! + * \brief Move the ObjectPtr inside ObjectRef out + * \param obj The ObjectRef + * \return The result ObjectPtr + */ + static ObjectPtr MoveObjectPtr(ObjectRef* obj) { + ObjectPtr data = std::move(obj->data_); + return data; + } }; } // namespace runtime diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 50106c9..62c790f 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -94,7 +94,7 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) IRModuleNode* mod_ptr = mod.CopyOnWrite(); auto* func_dict = mod_ptr->functions.CopyOnWrite(); // directly loop over the underlying dict - for (auto& kv : func_dict->data) { + for (auto& kv : *func_dict) { // only picks up tir::PrimFunc if (kv.second->IsInstance()) { // move out the function so that it is the only copy. @@ -110,7 +110,7 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) // automatic removal of None for (const auto& gv : deleted_list) { - func_dict->data.erase(gv); + func_dict->erase(gv); } pass_ctx.Trace(mod, pass_info, false); return mod; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 67336d4..0684189 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -277,7 +277,7 @@ Pass SplitHostDevice() { auto* func_dict = mod_ptr->functions.CopyOnWrite(); IRModule device_mod = IRModule(); - for (auto& kv : func_dict->data) { + for (auto& kv : *func_dict) { if (kv.second->IsInstance()) { PrimFunc func = Downcast(std::move(kv.second)); kv.second = SplitHostDevice(std::move(func), &device_mod); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index eca65ee..d1d6ffb 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -310,9 +310,7 @@ TEST(Map, Mutate) { CHECK(it != dict.end() && (*it).second.same_as(x)); it = dict2.find(zz); - CHECK(it == dict.end()); - - LOG(INFO) << dict; + CHECK(it == dict2.end()); } TEST(Map, Iterator) { @@ -324,6 +322,51 @@ TEST(Map, Iterator) { CHECK(map2[a].as()->value == 2); } +TEST(Map, Insert) { + using namespace tvm; + auto check = [](const Map& result, + std::unordered_map expected) { + CHECK_EQ(result.size(), expected.size()); + for (const auto& kv : result) { + CHECK(expected.count(kv.first)); + CHECK_EQ(expected[kv.first], kv.second.operator int64_t()); + expected.erase(kv.first); + } + }; + Map result; + std::unordered_map expected; + char key = 'a'; + int64_t val = 1; + for (int i = 0; i < 26; ++i, ++key, ++val) { + std::string s(1, key); + result.Set(s, val); + expected[s] = val; + check(result, expected); + } +} + +TEST(Map, Erase) { + auto check = [](const Map& result, + std::unordered_map expected) { + CHECK_EQ(result.size(), expected.size()); + for (const auto& kv : result) { + CHECK(expected.count(kv.first)); + CHECK_EQ(expected[kv.first], kv.second.operator int64_t()); + expected.erase(kv.first); + } + }; + Map map{{"a", 1}, {"b", 2}, {"c", 3}, {"d", 4}, {"e", 5}}; + std::unordered_map stl(map.begin(), map.end()); + for (char c = 'a'; c <= 'e'; ++c) { + Map result = map; + std::unordered_map expected(stl); + std::string key(1, c); + result.erase(key); + expected.erase(key); + check(result, expected); + } +} + TEST(String, MoveFromStd) { using namespace std; string source = "this is a string"; -- 2.7.4