1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 #ifndef FST_EXTENSIONS_LINEAR_TRIE_H_
5 #define FST_EXTENSIONS_LINEAR_TRIE_H_
#include <cstddef>
#include <istream>
#include <iterator>
#include <ostream>
#include <unordered_map>
#include <utility>
#include <vector>

#include <fst/compat.h>
#include <fst/util.h>
// Sentinel node id: returned by the `Find()` methods of the trie
// topologies below when no edge with the requested label exists.
constexpr int kNoTrieNodeId = -1;

// Forward declarations of all available trie topologies. In both,
// `L` is the edge label type and `H` is a hash functor for `L`.
template <typename L, typename H>
class NestedTrieTopology;
template <typename L, typename H>
class FlatTrieTopology;
24 // A pair of parent node id and label, part of a trie edge
31 ParentLabel(int p, L l) : parent(p), label(l) {}
33 bool operator==(const ParentLabel &that) const {
34 return parent == that.parent && label == that.label;
37 std::istream &Read(std::istream &strm) { // NOLINT
38 ReadType(strm, &parent);
39 ReadType(strm, &label);
43 std::ostream &Write(std::ostream &strm) const { // NOLINT
44 WriteType(strm, parent);
45 WriteType(strm, label);
50 template <class L, class H>
51 struct ParentLabelHash {
52 size_t operator()(const ParentLabel<L> &pl) const {
53 return static_cast<size_t>(pl.parent * 7853 + H()(pl.label));
57 // The trie topology in a nested tree of hash maps; allows efficient
58 // iteration over children of a specific node.
59 template <class L, class H>
60 class NestedTrieTopology {
64 typedef std::unordered_map<L, int, H> NextMap;
66 class const_iterator {
68 typedef std::forward_iterator_tag iterator_category;
69 typedef std::pair<ParentLabel<L>, int> value_type;
70 typedef std::ptrdiff_t difference_type;
71 typedef const value_type *pointer;
72 typedef const value_type &reference;
74 friend class NestedTrieTopology<L, H>;
76 const_iterator() : ptr_(nullptr), cur_node_(kNoTrieNodeId), cur_edge_() {}
78 reference operator*() {
82 pointer operator->() {
87 const_iterator &operator++();
88 const_iterator &operator++(int); // NOLINT
90 bool operator==(const const_iterator &that) const {
91 return ptr_ == that.ptr_ && cur_node_ == that.cur_node_ &&
92 cur_edge_ == that.cur_edge_;
94 bool operator!=(const const_iterator &that) const {
95 return !(*this == that);
99 const_iterator(const NestedTrieTopology *ptr, int cur_node)
100 : ptr_(ptr), cur_node_(cur_node) {
104 void SetProperCurEdge() {
105 if (cur_node_ < ptr_->NumNodes())
106 cur_edge_ = ptr_->nodes_[cur_node_]->begin();
108 cur_edge_ = ptr_->nodes_[0]->begin();
112 stub_.first = ParentLabel<L>(cur_node_, cur_edge_->first);
113 stub_.second = cur_edge_->second;
116 const NestedTrieTopology *ptr_;
118 typename NextMap::const_iterator cur_edge_;
122 NestedTrieTopology();
123 NestedTrieTopology(const NestedTrieTopology &that);
124 ~NestedTrieTopology();
125 void swap(NestedTrieTopology &that);
126 NestedTrieTopology &operator=(const NestedTrieTopology &that);
127 bool operator==(const NestedTrieTopology &that) const;
128 bool operator!=(const NestedTrieTopology &that) const;
130 int Root() const { return 0; }
131 size_t NumNodes() const { return nodes_.size(); }
132 int Insert(int parent, const L &label);
133 int Find(int parent, const L &label) const;
134 const NextMap &ChildrenOf(int parent) const { return *nodes_[parent]; }
136 std::istream &Read(std::istream &strm); // NOLINT
137 std::ostream &Write(std::ostream &strm) const; // NOLINT
139 const_iterator begin() const { return const_iterator(this, 0); }
140 const_iterator end() const { return const_iterator(this, NumNodes()); }
143 std::vector<NextMap *>
144 nodes_; // Use pointers to avoid copying the maps when the
148 template <class L, class H>
149 NestedTrieTopology<L, H>::NestedTrieTopology() {
150 nodes_.push_back(new NextMap);
153 template <class L, class H>
154 NestedTrieTopology<L, H>::NestedTrieTopology(const NestedTrieTopology &that) {
155 nodes_.reserve(that.nodes_.size());
156 for (size_t i = 0; i < that.nodes_.size(); ++i) {
157 NextMap *node = that.nodes_[i];
158 nodes_.push_back(new NextMap(*node));
162 template <class L, class H>
163 NestedTrieTopology<L, H>::~NestedTrieTopology() {
164 for (size_t i = 0; i < nodes_.size(); ++i) {
165 NextMap *node = nodes_[i];
170 // TODO(wuke): std::swap compatibility
171 template <class L, class H>
172 inline void NestedTrieTopology<L, H>::swap(NestedTrieTopology &that) {
173 nodes_.swap(that.nodes_);
176 template <class L, class H>
177 inline NestedTrieTopology<L, H> &NestedTrieTopology<L, H>::operator=(
178 const NestedTrieTopology &that) {
179 NestedTrieTopology copy(that);
184 template <class L, class H>
185 inline bool NestedTrieTopology<L, H>::operator==(
186 const NestedTrieTopology &that) const {
187 if (NumNodes() != that.NumNodes()) return false;
188 for (int i = 0; i < NumNodes(); ++i)
189 if (ChildrenOf(i) != that.ChildrenOf(i)) return false;
193 template <class L, class H>
194 inline bool NestedTrieTopology<L, H>::operator!=(
195 const NestedTrieTopology &that) const {
196 return !(*this == that);
199 template <class L, class H>
200 inline int NestedTrieTopology<L, H>::Insert(int parent, const L &label) {
201 int ret = Find(parent, label);
202 if (ret == kNoTrieNodeId) {
204 (*nodes_[parent])[label] = ret;
205 nodes_.push_back(new NextMap);
210 template <class L, class H>
211 inline int NestedTrieTopology<L, H>::Find(int parent, const L &label) const {
212 typename NextMap::const_iterator it = nodes_[parent]->find(label);
213 return it == nodes_[parent]->end() ? kNoTrieNodeId : it->second;
216 template <class L, class H>
217 inline std::istream &NestedTrieTopology<L, H>::Read(
218 std::istream &strm) { // NOLINT
219 NestedTrieTopology new_trie;
221 if (!ReadType(strm, &num_nodes)) return strm;
222 for (size_t i = 1; i < num_nodes; ++i) new_trie.nodes_.push_back(new NextMap);
223 for (size_t i = 0; i < num_nodes; ++i) ReadType(strm, new_trie.nodes_[i]);
224 if (strm) swap(new_trie);
228 template <class L, class H>
229 inline std::ostream &NestedTrieTopology<L, H>::Write(
230 std::ostream &strm) const { // NOLINT
231 WriteType(strm, NumNodes());
232 for (size_t i = 0; i < NumNodes(); ++i) WriteType(strm, *nodes_[i]);
236 template <class L, class H>
237 inline typename NestedTrieTopology<L, H>::const_iterator
238 &NestedTrieTopology<L, H>::const_iterator::operator++() {
240 if (cur_edge_ == ptr_->nodes_[cur_node_]->end()) {
242 while (cur_node_ < ptr_->NumNodes() && ptr_->nodes_[cur_node_]->empty())
249 template <class L, class H>
250 inline typename NestedTrieTopology<L, H>::const_iterator
251 &NestedTrieTopology<L, H>::const_iterator::operator++(int) { // NOLINT
252 const_iterator save(*this);
257 // The trie topology in a single hash map; only allows iteration over
258 // all the edges in arbitrary order.
259 template <class L, class H>
260 class FlatTrieTopology {
262 typedef std::unordered_map<ParentLabel<L>, int, ParentLabelHash<L, H>>
266 // Iterator over edges as std::pair<ParentLabel<L>, int>
267 typedef typename NextMap::const_iterator const_iterator;
271 FlatTrieTopology() {}
272 FlatTrieTopology(const FlatTrieTopology &that) : next_(that.next_) {}
274 explicit FlatTrieTopology(const T &that);
276 // TODO(wuke): std::swap compatibility
277 void swap(FlatTrieTopology &that) { next_.swap(that.next_); }
279 bool operator==(const FlatTrieTopology &that) const {
280 return next_ == that.next_;
282 bool operator!=(const FlatTrieTopology &that) const {
283 return !(*this == that);
286 int Root() const { return 0; }
287 size_t NumNodes() const { return next_.size() + 1; }
288 int Insert(int parent, const L &label);
289 int Find(int parent, const L &label) const;
291 std::istream &Read(std::istream &strm) { // NOLINT
292 return ReadType(strm, &next_);
294 std::ostream &Write(std::ostream &strm) const { // NOLINT
295 return WriteType(strm, next_);
298 const_iterator begin() const { return next_.begin(); }
299 const_iterator end() const { return next_.end(); }
305 template <class L, class H>
307 FlatTrieTopology<L, H>::FlatTrieTopology(const T &that)
308 : next_(that.begin(), that.end()) {}
310 template <class L, class H>
311 inline int FlatTrieTopology<L, H>::Insert(int parent, const L &label) {
312 int ret = Find(parent, label);
313 if (ret == kNoTrieNodeId) {
315 next_[ParentLabel<L>(parent, label)] = ret;
320 template <class L, class H>
321 inline int FlatTrieTopology<L, H>::Find(int parent, const L &label) const {
322 typename NextMap::const_iterator it =
323 next_.find(ParentLabel<L>(parent, label));
324 return it == next_.end() ? kNoTrieNodeId : it->second;
327 // A collection of implementations of the trie data structure. The key
328 // is a sequence of type `L` which must be hashable. The value is of
329 // `V` which must be default constructible and copyable. In addition,
330 // a value object is stored for each node in the trie therefore
331 // copying `V` should be cheap.
333 // One can access the store values with an integer node id, using the
334 // [] operator. A valid node id can be obtained by the following ways:
336 // 1. Using the `Root()` method to get the node id of the root.
338 // 2. Iterating through 0 to `NumNodes() - 1`. The node ids are dense
339 // so every integer in this range is a valid node id.
341 // 3. Using the node id returned from a successful `Insert()` or
344 // 4. Iterating over the trie edges with an `EdgeIterator` and using
345 // the node ids returned from its `Parent()` and `Child()` methods.
347 // Below is an example of inserting keys into the trie:
349 // const string words[] = {"hello", "health", "jello"};
350 // Trie<char, bool> dict;
351 // for (auto word : words) {
352 // int cur = dict.Root();
353 // for (char c : word) {
354 // cur = dict.Insert(cur, c);
359 // And the following is an example of looking up the longest prefix of
360 // a string using the trie constructed above:
362 // string query = "healed";
363 // size_t prefix_length = 0;
364 // int cur = dict.Find(dict.Root(), query[prefix_length]);
365 // while (prefix_length < query.size() &&
366 // cur != Trie<char, bool>::kNoNodeId) {
368 // cur = dict.Find(cur, query[prefix_length]);
370 template <class L, class V, class T>
373 template <class LL, class VV, class TT>
374 friend class MutableTrie;
380 // Constructs a trie with only the root node.
383 // Conversion from another trie of a possiblly different
384 // topology. The underlying topology must supported conversion.
386 explicit MutableTrie(const MutableTrie<L, V, S> &that)
387 : topology_(that.topology_), values_(that.values_) {}
389 // TODO(wuke): std::swap compatibility
390 void swap(MutableTrie &that) {
391 topology_.swap(that.topology_);
392 values_.swap(that.values_);
395 int Root() const { return topology_.Root(); }
396 size_t NumNodes() const { return topology_.NumNodes(); }
398 // Inserts an edge with given `label` at node `parent`. Returns the
399 // child node id. If the node already exists, returns the node id
401 int Insert(int parent, const L &label) {
402 int ret = topology_.Insert(parent, label);
403 values_.resize(NumNodes());
407 // Finds the node id of the node from `parent` via `label`. Returns
408 // `kNoTrieNodeId` when such a node does not exist.
409 int Find(int parent, const L &label) const {
410 return topology_.Find(parent, label);
413 const T &TrieTopology() const { return topology_; }
415 // Accesses the value stored for the given node.
416 V &operator[](int node_id) { return values_[node_id]; }
417 const V &operator[](int node_id) const { return values_[node_id]; }
419 // Comparison by content
420 bool operator==(const MutableTrie &that) const {
421 return topology_ == that.topology_ && values_ == that.values_;
424 bool operator!=(const MutableTrie &that) const { return !(*this == that); }
426 std::istream &Read(std::istream &strm) { // NOLINT
427 ReadType(strm, &topology_);
428 ReadType(strm, &values_);
431 std::ostream &Write(std::ostream &strm) const { // NOLINT
432 WriteType(strm, topology_);
433 WriteType(strm, values_);
439 std::vector<V> values_;
444 #endif // FST_EXTENSIONS_LINEAR_TRIE_H_