Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / extensions / linear / trie.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3
4 #ifndef FST_EXTENSIONS_LINEAR_TRIE_H_
5 #define FST_EXTENSIONS_LINEAR_TRIE_H_
6
7 #include <unordered_map>
8 #include <utility>
9 #include <vector>
10
11 #include <fst/compat.h>
12 #include <fst/util.h>
13
14 namespace fst {
15
16 const int kNoTrieNodeId = -1;
17
18 // Forward declarations of all available trie topologies.
19 template <class L, class H>
20 class NestedTrieTopology;
21 template <class L, class H>
22 class FlatTrieTopology;
23
24 // A pair of parent node id and label, part of a trie edge
25 template <class L>
26 struct ParentLabel {
27   int parent;
28   L label;
29
30   ParentLabel() {}
31   ParentLabel(int p, L l) : parent(p), label(l) {}
32
33   bool operator==(const ParentLabel &that) const {
34     return parent == that.parent && label == that.label;
35   }
36
37   std::istream &Read(std::istream &strm) {  // NOLINT
38     ReadType(strm, &parent);
39     ReadType(strm, &label);
40     return strm;
41   }
42
43   std::ostream &Write(std::ostream &strm) const {  // NOLINT
44     WriteType(strm, parent);
45     WriteType(strm, label);
46     return strm;
47   }
48 };
49
50 template <class L, class H>
51 struct ParentLabelHash {
52   size_t operator()(const ParentLabel<L> &pl) const {
53     return static_cast<size_t>(pl.parent * 7853 + H()(pl.label));
54   }
55 };
56
57 // The trie topology in a nested tree of hash maps; allows efficient
58 // iteration over children of a specific node.
59 template <class L, class H>
60 class NestedTrieTopology {
61  public:
62   typedef L Label;
63   typedef H Hash;
64   typedef std::unordered_map<L, int, H> NextMap;
65
66   class const_iterator {
67    public:
68     typedef std::forward_iterator_tag iterator_category;
69     typedef std::pair<ParentLabel<L>, int> value_type;
70     typedef std::ptrdiff_t difference_type;
71     typedef const value_type *pointer;
72     typedef const value_type &reference;
73
74     friend class NestedTrieTopology<L, H>;
75
76     const_iterator() : ptr_(nullptr), cur_node_(kNoTrieNodeId), cur_edge_() {}
77
78     reference operator*() {
79       UpdateStub();
80       return stub_;
81     }
82     pointer operator->() {
83       UpdateStub();
84       return &stub_;
85     }
86
87     const_iterator &operator++();
88     const_iterator &operator++(int);  // NOLINT
89
90     bool operator==(const const_iterator &that) const {
91       return ptr_ == that.ptr_ && cur_node_ == that.cur_node_ &&
92              cur_edge_ == that.cur_edge_;
93     }
94     bool operator!=(const const_iterator &that) const {
95       return !(*this == that);
96     }
97
98    private:
99     const_iterator(const NestedTrieTopology *ptr, int cur_node)
100         : ptr_(ptr), cur_node_(cur_node) {
101       SetProperCurEdge();
102     }
103
104     void SetProperCurEdge() {
105       if (cur_node_ < ptr_->NumNodes())
106         cur_edge_ = ptr_->nodes_[cur_node_]->begin();
107       else
108         cur_edge_ = ptr_->nodes_[0]->begin();
109     }
110
111     void UpdateStub() {
112       stub_.first = ParentLabel<L>(cur_node_, cur_edge_->first);
113       stub_.second = cur_edge_->second;
114     }
115
116     const NestedTrieTopology *ptr_;
117     int cur_node_;
118     typename NextMap::const_iterator cur_edge_;
119     value_type stub_;
120   };
121
122   NestedTrieTopology();
123   NestedTrieTopology(const NestedTrieTopology &that);
124   ~NestedTrieTopology();
125   void swap(NestedTrieTopology &that);
126   NestedTrieTopology &operator=(const NestedTrieTopology &that);
127   bool operator==(const NestedTrieTopology &that) const;
128   bool operator!=(const NestedTrieTopology &that) const;
129
130   int Root() const { return 0; }
131   size_t NumNodes() const { return nodes_.size(); }
132   int Insert(int parent, const L &label);
133   int Find(int parent, const L &label) const;
134   const NextMap &ChildrenOf(int parent) const { return *nodes_[parent]; }
135
136   std::istream &Read(std::istream &strm);         // NOLINT
137   std::ostream &Write(std::ostream &strm) const;  // NOLINT
138
139   const_iterator begin() const { return const_iterator(this, 0); }
140   const_iterator end() const { return const_iterator(this, NumNodes()); }
141
142  private:
143   std::vector<NextMap *>
144       nodes_;  // Use pointers to avoid copying the maps when the
145                // vector grows
146 };
147
148 template <class L, class H>
149 NestedTrieTopology<L, H>::NestedTrieTopology() {
150   nodes_.push_back(new NextMap);
151 }
152
153 template <class L, class H>
154 NestedTrieTopology<L, H>::NestedTrieTopology(const NestedTrieTopology &that) {
155   nodes_.reserve(that.nodes_.size());
156   for (size_t i = 0; i < that.nodes_.size(); ++i) {
157     NextMap *node = that.nodes_[i];
158     nodes_.push_back(new NextMap(*node));
159   }
160 }
161
162 template <class L, class H>
163 NestedTrieTopology<L, H>::~NestedTrieTopology() {
164   for (size_t i = 0; i < nodes_.size(); ++i) {
165     NextMap *node = nodes_[i];
166     delete node;
167   }
168 }
169
170 // TODO(wuke): std::swap compatibility
171 template <class L, class H>
172 inline void NestedTrieTopology<L, H>::swap(NestedTrieTopology &that) {
173   nodes_.swap(that.nodes_);
174 }
175
176 template <class L, class H>
177 inline NestedTrieTopology<L, H> &NestedTrieTopology<L, H>::operator=(
178     const NestedTrieTopology &that) {
179   NestedTrieTopology copy(that);
180   swap(copy);
181   return *this;
182 }
183
184 template <class L, class H>
185 inline bool NestedTrieTopology<L, H>::operator==(
186     const NestedTrieTopology &that) const {
187   if (NumNodes() != that.NumNodes()) return false;
188   for (int i = 0; i < NumNodes(); ++i)
189     if (ChildrenOf(i) != that.ChildrenOf(i)) return false;
190   return true;
191 }
192
193 template <class L, class H>
194 inline bool NestedTrieTopology<L, H>::operator!=(
195     const NestedTrieTopology &that) const {
196   return !(*this == that);
197 }
198
199 template <class L, class H>
200 inline int NestedTrieTopology<L, H>::Insert(int parent, const L &label) {
201   int ret = Find(parent, label);
202   if (ret == kNoTrieNodeId) {
203     ret = NumNodes();
204     (*nodes_[parent])[label] = ret;
205     nodes_.push_back(new NextMap);
206   }
207   return ret;
208 }
209
210 template <class L, class H>
211 inline int NestedTrieTopology<L, H>::Find(int parent, const L &label) const {
212   typename NextMap::const_iterator it = nodes_[parent]->find(label);
213   return it == nodes_[parent]->end() ? kNoTrieNodeId : it->second;
214 }
215
216 template <class L, class H>
217 inline std::istream &NestedTrieTopology<L, H>::Read(
218     std::istream &strm) {  // NOLINT
219   NestedTrieTopology new_trie;
220   size_t num_nodes;
221   if (!ReadType(strm, &num_nodes)) return strm;
222   for (size_t i = 1; i < num_nodes; ++i) new_trie.nodes_.push_back(new NextMap);
223   for (size_t i = 0; i < num_nodes; ++i) ReadType(strm, new_trie.nodes_[i]);
224   if (strm) swap(new_trie);
225   return strm;
226 }
227
228 template <class L, class H>
229 inline std::ostream &NestedTrieTopology<L, H>::Write(
230     std::ostream &strm) const {  // NOLINT
231   WriteType(strm, NumNodes());
232   for (size_t i = 0; i < NumNodes(); ++i) WriteType(strm, *nodes_[i]);
233   return strm;
234 }
235
236 template <class L, class H>
237 inline typename NestedTrieTopology<L, H>::const_iterator
238     &NestedTrieTopology<L, H>::const_iterator::operator++() {
239   ++cur_edge_;
240   if (cur_edge_ == ptr_->nodes_[cur_node_]->end()) {
241     ++cur_node_;
242     while (cur_node_ < ptr_->NumNodes() && ptr_->nodes_[cur_node_]->empty())
243       ++cur_node_;
244     SetProperCurEdge();
245   }
246   return *this;
247 }
248
249 template <class L, class H>
250 inline typename NestedTrieTopology<L, H>::const_iterator
251     &NestedTrieTopology<L, H>::const_iterator::operator++(int) {  // NOLINT
252   const_iterator save(*this);
253   ++(*this);
254   return save;
255 }
256
257 // The trie topology in a single hash map; only allows iteration over
258 // all the edges in arbitrary order.
259 template <class L, class H>
260 class FlatTrieTopology {
261  private:
262   typedef std::unordered_map<ParentLabel<L>, int, ParentLabelHash<L, H>>
263       NextMap;
264
265  public:
266   // Iterator over edges as std::pair<ParentLabel<L>, int>
267   typedef typename NextMap::const_iterator const_iterator;
268   typedef L Label;
269   typedef H Hash;
270
271   FlatTrieTopology() {}
272   FlatTrieTopology(const FlatTrieTopology &that) : next_(that.next_) {}
273   template <class T>
274   explicit FlatTrieTopology(const T &that);
275
276   // TODO(wuke): std::swap compatibility
277   void swap(FlatTrieTopology &that) { next_.swap(that.next_); }
278
279   bool operator==(const FlatTrieTopology &that) const {
280     return next_ == that.next_;
281   }
282   bool operator!=(const FlatTrieTopology &that) const {
283     return !(*this == that);
284   }
285
286   int Root() const { return 0; }
287   size_t NumNodes() const { return next_.size() + 1; }
288   int Insert(int parent, const L &label);
289   int Find(int parent, const L &label) const;
290
291   std::istream &Read(std::istream &strm) {  // NOLINT
292     return ReadType(strm, &next_);
293   }
294   std::ostream &Write(std::ostream &strm) const {  // NOLINT
295     return WriteType(strm, next_);
296   }
297
298   const_iterator begin() const { return next_.begin(); }
299   const_iterator end() const { return next_.end(); }
300
301  private:
302   NextMap next_;
303 };
304
305 template <class L, class H>
306 template <class T>
307 FlatTrieTopology<L, H>::FlatTrieTopology(const T &that)
308     : next_(that.begin(), that.end()) {}
309
310 template <class L, class H>
311 inline int FlatTrieTopology<L, H>::Insert(int parent, const L &label) {
312   int ret = Find(parent, label);
313   if (ret == kNoTrieNodeId) {
314     ret = NumNodes();
315     next_[ParentLabel<L>(parent, label)] = ret;
316   }
317   return ret;
318 }
319
320 template <class L, class H>
321 inline int FlatTrieTopology<L, H>::Find(int parent, const L &label) const {
322   typename NextMap::const_iterator it =
323       next_.find(ParentLabel<L>(parent, label));
324   return it == next_.end() ? kNoTrieNodeId : it->second;
325 }
326
327 // A collection of implementations of the trie data structure. The key
328 // is a sequence of type `L` which must be hashable. The value is of
329 // `V` which must be default constructible and copyable. In addition,
330 // a value object is stored for each node in the trie therefore
331 // copying `V` should be cheap.
332 //
333 // One can access the store values with an integer node id, using the
334 // [] operator. A valid node id can be obtained by the following ways:
335 //
336 // 1. Using the `Root()` method to get the node id of the root.
337 //
338 // 2. Iterating through 0 to `NumNodes() - 1`. The node ids are dense
339 // so every integer in this range is a valid node id.
340 //
341 // 3. Using the node id returned from a successful `Insert()` or
342 // `Find()` call.
343 //
344 // 4. Iterating over the trie edges with an `EdgeIterator` and using
345 // the node ids returned from its `Parent()` and `Child()` methods.
346 //
347 // Below is an example of inserting keys into the trie:
348 //
349 //   const string words[] = {"hello", "health", "jello"};
350 //   Trie<char, bool> dict;
351 //   for (auto word : words) {
352 //     int cur = dict.Root();
353 //     for (char c : word) {
354 //       cur = dict.Insert(cur, c);
355 //     }
356 //     dict[cur] = true;
357 //   }
358 //
359 // And the following is an example of looking up the longest prefix of
360 // a string using the trie constructed above:
361 //
362 //   string query = "healed";
363 //   size_t prefix_length = 0;
364 //   int cur = dict.Find(dict.Root(), query[prefix_length]);
365 //   while (prefix_length < query.size() &&
366 //     cur != Trie<char, bool>::kNoNodeId) {
367 //     ++prefix_length;
368 //     cur = dict.Find(cur, query[prefix_length]);
369 //   }
370 template <class L, class V, class T>
371 class MutableTrie {
372  public:
373   template <class LL, class VV, class TT>
374   friend class MutableTrie;
375
376   typedef L Label;
377   typedef V Value;
378   typedef T Topology;
379
380   // Constructs a trie with only the root node.
381   MutableTrie() {}
382
383   // Conversion from another trie of a possiblly different
384   // topology. The underlying topology must supported conversion.
385   template <class S>
386   explicit MutableTrie(const MutableTrie<L, V, S> &that)
387       : topology_(that.topology_), values_(that.values_) {}
388
389   // TODO(wuke): std::swap compatibility
390   void swap(MutableTrie &that) {
391     topology_.swap(that.topology_);
392     values_.swap(that.values_);
393   }
394
395   int Root() const { return topology_.Root(); }
396   size_t NumNodes() const { return topology_.NumNodes(); }
397
398   // Inserts an edge with given `label` at node `parent`. Returns the
399   // child node id. If the node already exists, returns the node id
400   // right away.
401   int Insert(int parent, const L &label) {
402     int ret = topology_.Insert(parent, label);
403     values_.resize(NumNodes());
404     return ret;
405   }
406
407   // Finds the node id of the node from `parent` via `label`. Returns
408   // `kNoTrieNodeId` when such a node does not exist.
409   int Find(int parent, const L &label) const {
410     return topology_.Find(parent, label);
411   }
412
413   const T &TrieTopology() const { return topology_; }
414
415   // Accesses the value stored for the given node.
416   V &operator[](int node_id) { return values_[node_id]; }
417   const V &operator[](int node_id) const { return values_[node_id]; }
418
419   // Comparison by content
420   bool operator==(const MutableTrie &that) const {
421     return topology_ == that.topology_ && values_ == that.values_;
422   }
423
424   bool operator!=(const MutableTrie &that) const { return !(*this == that); }
425
426   std::istream &Read(std::istream &strm) {  // NOLINT
427     ReadType(strm, &topology_);
428     ReadType(strm, &values_);
429     return strm;
430   }
431   std::ostream &Write(std::ostream &strm) const {  // NOLINT
432     WriteType(strm, topology_);
433     WriteType(strm, values_);
434     return strm;
435   }
436
437  private:
438   T topology_;
439   std::vector<V> values_;
440 };
441
442 }  // namespace fst
443
444 #endif  // FST_EXTENSIONS_LINEAR_TRIE_H_