Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / bi-table.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes for representing a bijective mapping between an arbitrary entry
5 // of type T and a signed integral ID.
6
7 #ifndef FST_LIB_BI_TABLE_H_
8 #define FST_LIB_BI_TABLE_H_
9
10 #include <deque>
11 #include <memory>
12 #include <functional>
13 #include <unordered_map>
14 #include <unordered_set>
15 #include <vector>
16
17 #include <fst/log.h>
18 #include <fst/memory.h>
19
20 namespace fst {
21
22 // Bitables model bijective mappings between entries of an arbitrary type T and
23 // an signed integral ID of type I. The IDs are allocated starting from 0 in
24 // order.
25 //
26 // template <class I, class T>
27 // class BiTable {
28 //  public:
29 //
30 //   // Required constructors.
31 //   BiTable();
32 //
33 //   // Looks up integer ID from entry. If it doesn't exist and insert
34 //   / is true, adds it; otherwise, returns -1.
35 //   I FindId(const T &entry, bool insert = true);
36 //
37 //   // Looks up entry from integer ID.
38 //   const T &FindEntry(I) const;
39 //
40 //   // Returns number of stored entries.
41 //   I Size() const;
42 // };
43
44 // An implementation using a hash map for the entry to ID mapping. H is the
45 // hash function and E is the equality function. If passed to the constructor,
46 // ownership is given to this class.
47 template <class I, class T, class H, class E = std::equal_to<T>>
48 class HashBiTable {
49  public:
50   // Reserves space for table_size elements. If passing H and E to the
51   // constructor, this class owns them.
52   explicit HashBiTable(size_t table_size = 0, H *h = nullptr, E *e = nullptr) :
53       hash_func_(h ? h : new H()), hash_equal_(e ? e : new E()),
54       entry2id_(table_size, *hash_func_, *hash_equal_) {
55     if (table_size) id2entry_.reserve(table_size);
56   }
57
58   HashBiTable(const HashBiTable<I, T, H, E> &table)
59       : hash_func_(new H(*table.hash_func_)),
60         hash_equal_(new E(*table.hash_equal_)),
61         entry2id_(table.entry2id_.begin(), table.entry2id_.end(),
62                   table.entry2id_.size(), *hash_func_, *hash_equal_),
63         id2entry_(table.id2entry_) {}
64
65   I FindId(const T &entry, bool insert = true) {
66     if (!insert) {
67       const auto it = entry2id_.find(entry);
68       return it == entry2id_.end() ? -1 : it->second - 1;
69     }
70     I &id_ref = entry2id_[entry];
71     if (id_ref == 0) {  // T not found; stores and assigns a new ID.
72       id2entry_.push_back(entry);
73       id_ref = id2entry_.size();
74     }
75     return id_ref - 1;  // NB: id_ref = ID + 1.
76   }
77
78   const T &FindEntry(I s) const { return id2entry_[s]; }
79
80   I Size() const { return id2entry_.size(); }
81
82   // TODO(riley): Add fancy clear-to-size, as in CompactHashBiTable.
83   void Clear() {
84     entry2id_.clear();
85     id2entry_.clear();
86   }
87
88  private:
89   std::unique_ptr<H> hash_func_;
90   std::unique_ptr<E> hash_equal_;
91   std::unordered_map<T, I, H, E> entry2id_;
92   std::vector<T> id2entry_;
93 };
94
95 // Enables alternative hash set representations below.
96 enum HSType { HS_STL = 0, HS_DENSE = 1, HS_SPARSE = 2, HS_FLAT = 3 };
97
98 // Default hash set is STL hash_set.
99 template <class K, class H, class E, HSType HS>
100 struct HashSet : public std::unordered_set<K, H, E, PoolAllocator<K>> {
101   explicit HashSet(size_t n = 0, const H &h = H(), const E &e = E())
102       : std::unordered_set<K, H, E, PoolAllocator<K>>(n, h, e) {}
103
104   void rehash(size_t n) {}
105 };
106
107 // An implementation using a hash set for the entry to ID mapping. The hash set
108 // holds keys which are either the ID or kCurrentKey. These keys can be mapped
109 // to entries either by looking up in the entry vector or, if kCurrentKey, in
110 // current_entry_. The hash and key equality functions map to entries first. H
111 // is the hash function and E is the equality function. If passed to the
112 // constructor, ownership is given to this class.
113 template <class I, class T, class H, class E = std::equal_to<T>,
114           HSType HS = HS_FLAT>
115 class CompactHashBiTable {
116  public:
117   friend class HashFunc;
118   friend class HashEqual;
119
120   // Reserves space for table_size elements. If passing H and E to the
121   // constructor, this class owns them.
122   explicit CompactHashBiTable(size_t table_size = 0, H *h = nullptr,
123                               E *e = nullptr) :
124         hash_func_(h ? h : new H()), hash_equal_(e ? e : new E()),
125         compact_hash_func_(*this), compact_hash_equal_(*this),
126         keys_(table_size, compact_hash_func_, compact_hash_equal_) {
127     if (table_size) id2entry_.reserve(table_size);
128   }
129
130   CompactHashBiTable(const CompactHashBiTable<I, T, H, E, HS> &table)
131       : hash_func_(new H(*table.hash_func_)),
132         hash_equal_(new E(*table.hash_equal_)),
133         compact_hash_func_(*this), compact_hash_equal_(*this),
134         keys_(table.keys_.size(), compact_hash_func_, compact_hash_equal_),
135         id2entry_(table.id2entry_) {
136     keys_.insert(table.keys_.begin(), table.keys_.end());
137   }
138
139   I FindId(const T &entry, bool insert = true) {
140     current_entry_ = &entry;
141     if (insert) {
142       auto result = keys_.insert(kCurrentKey);
143       if (!result.second) return *result.first;  // Already exists.
144       // Overwrites kCurrentKey with a new key value; this is safe because it
145       // doesn't affect hashing or equality testing.
146       I key = id2entry_.size();
147       const_cast<I &>(*result.first) = key;
148       id2entry_.push_back(entry);
149       return key;
150     }
151     const auto it = keys_.find(kCurrentKey);
152     return it == keys_.end() ? -1 : *it;
153   }
154
155   const T &FindEntry(I s) const { return id2entry_[s]; }
156
157   I Size() const { return id2entry_.size(); }
158
159   // Clears content; with argument, erases last n IDs.
160   void Clear(ssize_t n = -1) {
161     if (n < 0 || n >= id2entry_.size()) {  // Clears completely.
162       keys_.clear();
163       id2entry_.clear();
164     } else if (n == id2entry_.size() - 1) {  // Leaves only key 0.
165       const T entry = FindEntry(0);
166       keys_.clear();
167       id2entry_.clear();
168       FindId(entry, true);
169     } else {
170       while (n-- > 0) {
171         I key = id2entry_.size() - 1;
172         keys_.erase(key);
173         id2entry_.pop_back();
174       }
175       keys_.rehash(0);
176     }
177   }
178
179  private:
180   static constexpr I kCurrentKey = -1;
181   static constexpr I kEmptyKey = -2;
182   static constexpr I kDeletedKey = -3;
183
184   class HashFunc {
185    public:
186     explicit HashFunc(const CompactHashBiTable &ht) : ht_(&ht) {}
187
188     size_t operator()(I k) const {
189       if (k >= kCurrentKey) {
190         return (*ht_->hash_func_)(ht_->Key2Entry(k));
191       } else {
192         return 0;
193       }
194     }
195
196    private:
197     const CompactHashBiTable *ht_;
198   };
199
200   class HashEqual {
201    public:
202     explicit HashEqual(const CompactHashBiTable &ht) : ht_(&ht) {}
203
204     bool operator()(I k1, I k2) const {
205       if (k1 == k2) {
206         return true;
207       } else if (k1 >= kCurrentKey && k2 >= kCurrentKey) {
208         return (*ht_->hash_equal_)(ht_->Key2Entry(k1), ht_->Key2Entry(k2));
209       } else {
210         return false;
211       }
212     }
213
214    private:
215     const CompactHashBiTable *ht_;
216   };
217
218   using KeyHashSet = HashSet<I, HashFunc, HashEqual, HS>;
219
220   const T &Key2Entry(I k) const {
221     if (k == kCurrentKey) {
222       return *current_entry_;
223     } else {
224       return id2entry_[k];
225     }
226   }
227
228   std::unique_ptr<H> hash_func_;
229   std::unique_ptr<E> hash_equal_;
230   HashFunc compact_hash_func_;
231   HashEqual compact_hash_equal_;
232   KeyHashSet keys_;
233   std::vector<T> id2entry_;
234   const T *current_entry_;
235 };
236
237 template <class I, class T, class H, class E, HSType HS>
238 constexpr I CompactHashBiTable<I, T, H, E, HS>::kCurrentKey;
239
240 template <class I, class T, class H, class E, HSType HS>
241 constexpr I CompactHashBiTable<I, T, H, E, HS>::kEmptyKey;
242
243 template <class I, class T, class H, class E, HSType HS>
244 constexpr I CompactHashBiTable<I, T, H, E, HS>::kDeletedKey;
245
246 // An implementation using a vector for the entry to ID mapping. It is passed a
247 // function object FP that should fingerprint entries uniquely to an integer
248 // that can used as a vector index. Normally, VectorBiTable constructs the FP
249 // object. The user can instead pass in this object; in that case, VectorBiTable
250 // takes its ownership.
251 template <class I, class T, class FP>
252 class VectorBiTable {
253  public:
254   // Reserves table_size cells of space. If passing FP argument to the
255   // constructor, this class owns it.
256   explicit VectorBiTable(FP *fp = nullptr, size_t table_size = 0) :
257       fp_(fp ? fp : new FP()) {
258     if (table_size) id2entry_.reserve(table_size);
259   }
260
261   VectorBiTable(const VectorBiTable<I, T, FP> &table)
262       : fp_(new FP(*table.fp_)), fp2id_(table.fp2id_),
263         id2entry_(table.id2entry_) {}
264
265   I FindId(const T &entry, bool insert = true) {
266     ssize_t fp = (*fp_)(entry);
267     if (fp >= fp2id_.size()) fp2id_.resize(fp + 1);
268     I &id_ref = fp2id_[fp];
269     if (id_ref == 0) {  // T not found.
270       if (insert) {     // Stores and assigns a new ID.
271         id2entry_.push_back(entry);
272         id_ref = id2entry_.size();
273       } else {
274         return -1;
275       }
276     }
277     return id_ref - 1;  // NB: id_ref = ID + 1.
278   }
279
280   const T &FindEntry(I s) const { return id2entry_[s]; }
281
282   I Size() const { return id2entry_.size(); }
283
284   const FP &Fingerprint() const { return *fp_; }
285
286  private:
287   std::unique_ptr<FP> fp_;
288   std::vector<I> fp2id_;
289   std::vector<T> id2entry_;
290 };
291
292 // An implementation using a vector and a compact hash table. The selecting
293 // functor S returns true for entries to be hashed in the vector. The
294 // fingerprinting functor FP returns a unique fingerprint for each entry to be
295 // hashed in the vector (these need to be suitable for indexing in a vector).
296 // The hash functor H is used when hashing entry into the compact hash table.
297 // If passed to the constructor, ownership is given to this class.
298 template <class I, class T, class S, class FP, class H, HSType HS = HS_DENSE>
299 class VectorHashBiTable {
300  public:
301   friend class HashFunc;
302   friend class HashEqual;
303
304   explicit VectorHashBiTable(S *s, FP *fp, H *h, size_t vector_size = 0,
305                              size_t entry_size = 0)
306       : selector_(s), fp_(fp), h_(h), hash_func_(*this), hash_equal_(*this),
307         keys_(0, hash_func_, hash_equal_) {
308     if (vector_size) fp2id_.reserve(vector_size);
309     if (entry_size) id2entry_.reserve(entry_size);
310   }
311
312   VectorHashBiTable(const VectorHashBiTable<I, T, S, FP, H, HS> &table)
313       : selector_(new S(table.s_)), fp_(new FP(*table.fp_)),
314         h_(new H(*table.h_)), id2entry_(table.id2entry_),
315         fp2id_(table.fp2id_), hash_func_(*this), hash_equal_(*this),
316         keys_(table.keys_.size(), hash_func_, hash_equal_) {
317     keys_.insert(table.keys_.begin(), table.keys_.end());
318   }
319
320   I FindId(const T &entry, bool insert = true) {
321     if ((*selector_)(entry)) {  // Uses the vector if selector_(entry) == true.
322       uint64 fp = (*fp_)(entry);
323       if (fp2id_.size() <= fp) fp2id_.resize(fp + 1, 0);
324       if (fp2id_[fp] == 0) {  // T not found.
325         if (insert) {         // Stores and assigns a new ID.
326           id2entry_.push_back(entry);
327           fp2id_[fp] = id2entry_.size();
328         } else {
329           return -1;
330         }
331       }
332       return fp2id_[fp] - 1;  // NB: assoc_value = ID + 1.
333     } else {                  // Uses the hash table otherwise.
334       current_entry_ = &entry;
335       const auto it = keys_.find(kCurrentKey);
336       if (it == keys_.end()) {
337         if (insert) {
338           I key = id2entry_.size();
339           id2entry_.push_back(entry);
340           keys_.insert(key);
341           return key;
342         } else {
343           return -1;
344         }
345       } else {
346         return *it;
347       }
348     }
349   }
350
351   const T &FindEntry(I s) const { return id2entry_[s]; }
352
353   I Size() const { return id2entry_.size(); }
354
355   const S &Selector() const { return *selector_; }
356
357   const FP &Fingerprint() const { return *fp_; }
358
359   const H &Hash() const { return *h_; }
360
361  private:
362   static constexpr I kCurrentKey = -1;
363   static constexpr I kEmptyKey = -2;
364
365   class HashFunc {
366    public:
367     explicit HashFunc(const VectorHashBiTable &ht) : ht_(&ht) {}
368
369     size_t operator()(I k) const {
370       if (k >= kCurrentKey) {
371         return (*(ht_->h_))(ht_->Key2Entry(k));
372       } else {
373         return 0;
374       }
375     }
376
377    private:
378     const VectorHashBiTable *ht_;
379   };
380
381   class HashEqual {
382    public:
383     explicit HashEqual(const VectorHashBiTable &ht) : ht_(&ht) {}
384
385     bool operator()(I k1, I k2) const {
386       if (k1 >= kCurrentKey && k2 >= kCurrentKey) {
387         return ht_->Key2Entry(k1) == ht_->Key2Entry(k2);
388       } else {
389         return k1 == k2;
390       }
391     }
392
393    private:
394     const VectorHashBiTable *ht_;
395   };
396
397   using KeyHashSet = HashSet<I, HashFunc, HashEqual, HS>;
398
399   const T &Key2Entry(I k) const {
400     if (k == kCurrentKey) {
401       return *current_entry_;
402     } else {
403       return id2entry_[k];
404     }
405   }
406
407   std::unique_ptr<S> selector_;  // True if entry hashed into vector.
408   std::unique_ptr<FP> fp_;       // Fingerprint used for hashing into vector.
409   std::unique_ptr<H> h_;         // Hash funcion used for hashing into hash_set.
410
411   std::vector<T> id2entry_;  // Maps state IDs to entry.
412   std::vector<I> fp2id_;     // Maps entry fingerprints to IDs.
413
414   // Compact implementation of the hash table mapping entries to state IDs
415   // using the hash function h_.
416   HashFunc hash_func_;
417   HashEqual hash_equal_;
418   KeyHashSet keys_;
419   const T *current_entry_;
420 };
421
422 template <class I, class T, class S, class FP, class H, HSType HS>
423 constexpr I VectorHashBiTable<I, T, S, FP, H, HS>::kCurrentKey;
424
425 template <class I, class T, class S, class FP, class H, HSType HS>
426 constexpr I VectorHashBiTable<I, T, S, FP, H, HS>::kEmptyKey;
427
428 // An implementation using a hash map for the entry to ID mapping. This version
429 // permits erasing of arbitrary states. The entry T must have == defined and
430 // its default constructor must produce a entry that will never be seen. F is
431 // the hash function.
432 template <class I, class T, class F>
433 class ErasableBiTable {
434  public:
435   ErasableBiTable() : first_(0) {}
436
437   I FindId(const T &entry, bool insert = true) {
438     I &id_ref = entry2id_[entry];
439     if (id_ref == 0) {  // T not found.
440       if (insert) {     // Stores and assigns a new ID.
441         id2entry_.push_back(entry);
442         id_ref = id2entry_.size() + first_;
443       } else {
444         return -1;
445       }
446     }
447     return id_ref - 1;  // NB: id_ref = ID + 1.
448   }
449
450   const T &FindEntry(I s) const { return id2entry_[s - first_]; }
451
452   I Size() const { return id2entry_.size(); }
453
454   void Erase(I s) {
455     auto &ref = id2entry_[s - first_];
456     entry2id_.erase(ref);
457     ref = empty_entry_;
458     while (!id2entry_.empty() && id2entry_.front() == empty_entry_) {
459       id2entry_.pop_front();
460       ++first_;
461     }
462   }
463
464  private:
465   std::unordered_map<T, I, F> entry2id_;
466   std::deque<T> id2entry_;
467   const T empty_entry_;
468   I first_;  // I of first element in the deque.
469 };
470
471 }  // namespace fst
472
473 #endif  // FST_LIB_BI_TABLE_H_