Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / cache.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // An FST implementation that caches FST elements of a delayed computation.
5
6 #ifndef FST_LIB_CACHE_H_
7 #define FST_LIB_CACHE_H_
8
9 #include <functional>
10 #include <unordered_map>
11 using std::unordered_map;
12 using std::unordered_multimap;
13 #include <list>
14 #include <vector>
15
16 #include <fst/log.h>
17
18 #include <fst/vector-fst.h>
19
20
21 DECLARE_bool(fst_default_cache_gc);
22 DECLARE_int64(fst_default_cache_gc_limit);
23
24 namespace fst {
25
26 // Options for controlling caching behavior; higher level than CacheImplOptions.
27 struct CacheOptions {
28   bool gc;          // Enables GC.
29   size_t gc_limit;  // Number of bytes allowed before GC.
30
31   explicit CacheOptions(bool gc = FLAGS_fst_default_cache_gc,
32                         size_t gc_limit = FLAGS_fst_default_cache_gc_limit)
33       : gc(gc), gc_limit(gc_limit) {}
34 };
35
36 // Options for controlling caching behavior, at a lower level than
37 // CacheOptions; templated on the cache store and allows passing the store.
38 template <class CacheStore>
39 struct CacheImplOptions {
40   bool gc;            // Enables GC.
41   size_t gc_limit;    // Number of bytes allowed before GC.
42   CacheStore *store;  // Cache store.
43   bool own_store;     // Should CacheImpl takes ownership of the store?
44
45   explicit CacheImplOptions(bool gc = FLAGS_fst_default_cache_gc,
46                             size_t gc_limit = FLAGS_fst_default_cache_gc_limit,
47                             CacheStore *store = nullptr)
48       : gc(gc), gc_limit(gc_limit), store(store), own_store(true) {}
49
50   explicit CacheImplOptions(const CacheOptions &opts)
51       : gc(opts.gc), gc_limit(opts.gc_limit), store(nullptr), own_store(true) {}
52 };
53
54 // Cache flags.
55 constexpr uint32 kCacheFinal = 0x0001;   // Final weight has been cached.
56 constexpr uint32 kCacheArcs = 0x0002;    // Arcs have been cached.
57 constexpr uint32 kCacheInit = 0x0004;    // Initialized by GC.
58 constexpr uint32 kCacheRecent = 0x0008;  // Visited since GC.
59 constexpr uint32 kCacheFlags =
60     kCacheFinal | kCacheArcs | kCacheInit | kCacheRecent;
61
62 // Cache state, with arcs stored in a per-state std::vector.
63 template <class A, class M = PoolAllocator<A>>
64 class CacheState {
65  public:
66   using Arc = A;
67   using Label = typename Arc::Label;
68   using StateId = typename Arc::StateId;
69   using Weight = typename Arc::Weight;
70
71   using ArcAllocator = M;
72   using StateAllocator =
73       typename ArcAllocator::template rebind<CacheState<A, M>>::other;
74
75   // Provides STL allocator for arcs.
76   explicit CacheState(const ArcAllocator &alloc)
77       : final_(Weight::Zero()),
78         niepsilons_(0),
79         noepsilons_(0),
80         arcs_(alloc),
81         flags_(0),
82         ref_count_(0) {}
83
84   CacheState(const CacheState<A> &state, const ArcAllocator &alloc)
85       : final_(state.Final()),
86         niepsilons_(state.NumInputEpsilons()),
87         noepsilons_(state.NumOutputEpsilons()),
88         arcs_(state.arcs_.begin(), state.arcs_.end(), alloc),
89         flags_(state.Flags()),
90         ref_count_(0) {}
91
92   void Reset() {
93     final_ = Weight::Zero();
94     niepsilons_ = 0;
95     noepsilons_ = 0;
96     ref_count_ = 0;
97     flags_ = 0;
98     arcs_.clear();
99   }
100
101   Weight Final() const { return final_; }
102
103   size_t NumInputEpsilons() const { return niepsilons_; }
104
105   size_t NumOutputEpsilons() const { return noepsilons_; }
106
107   size_t NumArcs() const { return arcs_.size(); }
108
109   const Arc &GetArc(size_t n) const { return arcs_[n]; }
110
111   // Used by the ArcIterator<Fst<Arc>> efficient implementation.
112   const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; }
113
114   // Accesses flags; used by the caller.
115   uint32 Flags() const { return flags_; }
116
117   // Accesses ref count; used by the caller.
118   int RefCount() const { return ref_count_; }
119
120   void SetFinal(Weight weight) { final_ = std::move(weight); }
121
122   void ReserveArcs(size_t n) { arcs_.reserve(n); }
123
124   // Adds one arc at a time with all needed book-keeping; use PushArc and
125   // SetArcs for a more efficient alternative.
126   void AddArc(const Arc &arc) {
127     arcs_.push_back(arc);
128     if (arc.ilabel == 0) ++niepsilons_;
129     if (arc.olabel == 0) ++noepsilons_;
130   }
131
132   // Adds one arc at a time with delayed book-keeping; finalize with SetArcs().
133   void PushArc(const Arc &arc) { arcs_.push_back(arc); }
134
135   // Finalizes arcs book-keeping; call only once.
136   void SetArcs() {
137     for (const auto &arc : arcs_) {
138       if (arc.ilabel == 0) ++niepsilons_;
139       if (arc.olabel == 0) ++noepsilons_;
140     }
141   }
142
143   // Modifies nth arc.
144   void SetArc(const Arc &arc, size_t n) {
145     if (arcs_[n].ilabel == 0) --niepsilons_;
146     if (arcs_[n].olabel == 0) --noepsilons_;
147     if (arc.ilabel == 0) ++niepsilons_;
148     if (arc.olabel == 0) ++noepsilons_;
149     arcs_[n] = arc;
150   }
151
152   // Deletes all arcs.
153   void DeleteArcs() {
154     niepsilons_ = 0;
155     noepsilons_ = 0;
156     arcs_.clear();
157   }
158
159   void DeleteArcs(size_t n) {
160     for (size_t i = 0; i < n; ++i) {
161       if (arcs_.back().ilabel == 0) --niepsilons_;
162       if (arcs_.back().olabel == 0) --noepsilons_;
163       arcs_.pop_back();
164     }
165   }
166
167   // Sets status flags; used by the caller.
168   void SetFlags(uint32 flags, uint32 mask) const {
169     flags_ &= ~mask;
170     flags_ |= flags;
171   }
172
173   // Mutates reference counts; used by the caller.
174
175   int IncrRefCount() const { return ++ref_count_; }
176
177   int DecrRefCount() const { return --ref_count_; }
178
179   // Used by the ArcIterator<Fst<Arc>> efficient implementation.
180   int *MutableRefCount() const { return &ref_count_; }
181
182   // Used for state class allocation.
183   void *operator new(size_t size, StateAllocator *alloc) {
184     return alloc->allocate(1);
185   }
186
187   // For state destruction and memory freeing.
188   static void Destroy(CacheState<Arc> *state, StateAllocator *alloc) {
189     if (state) {
190       state->~CacheState<Arc>();
191       alloc->deallocate(state, 1);
192     }
193   }
194
195  private:
196   Weight final_;                         // Final weight.
197   size_t niepsilons_;                    // # of input epsilons.
198   size_t noepsilons_;                    // # of output epsilons.
199   std::vector<Arc, ArcAllocator> arcs_;  // Arcs representation.
200   mutable uint32 flags_;
201   mutable int ref_count_;  // If 0, available for GC.
202 };
203
204 // Cache store, allocating and storing states, providing a mapping from state
205 // IDs to cached states, and an iterator over these states. The state template
206 // argument must implement the CacheState interface. The state for a StateId s
207 // is constructed when requested by GetMutableState(s) if it is not yet stored.
208 // Initially, a state has a reference count of zero, but the user may increment
209 // or decrement this to control the time of destruction. In particular, a state
210 // is destroyed when:
211 //
212 // 1. This instance is destroyed, or
213 // 2. Clear() or Delete() is called, or
214 // 3. Possibly (implementation-dependently) when:
215 //    - Garbage collection is enabled (as defined by opts.gc),
216 //    - The cache store size exceeds the limits (as defined by opts.gc_limits),
217 //    - The state's reference count is zero, and
218 //    - The state is not the most recently requested state.
219 //
220 // template <class S>
221 // class CacheStore {
222 //  public:
223 //   using State = S;
224 //   using Arc = typename State::Arc;
225 //   using StateId = typename Arc::StateId;
226 //
227 //   // Required constructors/assignment operators.
228 //   explicit CacheStore(const CacheOptions &opts);
229 //
230 //   // Returns nullptr if state is not stored.
231 //   const State *GetState(StateId s);
232 //
233 //   // Creates state if state is not stored.
234 //   State *GetMutableState(StateId s);
235 //
236 //   // Similar to State::AddArc() but updates cache store book-keeping.
237 //   void AddArc(State *state, const Arc &arc);
238 //
239 //   // Similar to State::SetArcs() but updates cache store book-keeping; call
240 //   // only once.
241 //   void SetArcs(State *state);
242 //
243 //   // Similar to State::DeleteArcs() but updates cache store book-keeping.
244 //
245 //   void DeleteArcs(State *state);
246 //
247 //   void DeleteArcs(State *state, size_t n);
248 //
249 //   // Deletes all cached states.
250 //   void Clear();
251 //
252 //   // Iterates over cached states (in an arbitrary order); only needed if
253 //   // opts.gc is true.
254 //   bool Done() const;      // End of iteration.
255 //   StateId Value() const;  // Current state.
256 //   void Next();            // Advances to next state (when !Done).
257 //   void Reset();           // Returns to initial condition.
258 //   void Delete();          // Deletes current state and advances to next.
259 // };
260
261 // Container cache stores.
262
263 // This class uses a vector of pointers to states to store cached states.
264 template <class S>
265 class VectorCacheStore {
266  public:
267   using State = S;
268   using Arc = typename State::Arc;
269   using StateId = typename Arc::StateId;
270   using StateList = std::list<StateId, PoolAllocator<StateId>>;
271
272   // Required constructors/assignment operators.
273   explicit VectorCacheStore(const CacheOptions &opts) : cache_gc_(opts.gc) {
274     Clear();
275     Reset();
276   }
277
278   VectorCacheStore(const VectorCacheStore<S> &store)
279       : cache_gc_(store.cache_gc_) {
280     CopyStates(store);
281     Reset();
282   }
283
284   ~VectorCacheStore() { Clear(); }
285
286   VectorCacheStore<State> &operator=(const VectorCacheStore<State> &store) {
287     if (this != &store) {
288       CopyStates(store);
289       Reset();
290     }
291     return *this;
292   }
293
294   // Returns nullptr if state is not stored.
295   const State *GetState(StateId s) const {
296     return s < state_vec_.size() ? state_vec_[s] : nullptr;
297   }
298
299   // Creates state if state is not stored.
300   State *GetMutableState(StateId s) {
301     State *state = nullptr;
302     if (s >= state_vec_.size()) {
303       state_vec_.resize(s + 1, nullptr);
304     } else {
305       state = state_vec_[s];
306     }
307     if (!state) {
308       state = new (&state_alloc_) State(arc_alloc_);
309       state_vec_[s] = state;
310       if (cache_gc_) state_list_.push_back(s);
311     }
312     return state;
313   }
314
315   // Similar to State::AddArc() but updates cache store book-keeping
316   void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }
317
318   // Similar to State::SetArcs() but updates cache store book-keeping; call
319   // only once.
320   void SetArcs(State *state) { state->SetArcs(); }
321
322   // Deletes all arcs.
323   void DeleteArcs(State *state) { state->DeleteArcs(); }
324
325   // Deletes some arcs.
326   void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }
327
328   // Deletes all cached states.
329   void Clear() {
330     for (StateId s = 0; s < state_vec_.size(); ++s) {
331       State::Destroy(state_vec_[s], &state_alloc_);
332     }
333     state_vec_.clear();
334     state_list_.clear();
335   }
336
337   // Iterates over cached states (in an arbitrary order); only works if GC is
338   // enabled (o.w. avoiding state_list_ overhead).
339   bool Done() const { return iter_ == state_list_.end(); }
340
341   StateId Value() const { return *iter_; }
342
343   void Next() { ++iter_; }
344
345   void Reset() { iter_ = state_list_.begin(); }
346
347   // Deletes current state and advances to next.
348   void Delete() {
349     State::Destroy(state_vec_[*iter_], &state_alloc_);
350     state_vec_[*iter_] = nullptr;
351     state_list_.erase(iter_++);
352   }
353
354  private:
355   void CopyStates(const VectorCacheStore<State> &store) {
356     Clear();
357     state_vec_.reserve(store.state_vec_.size());
358     for (StateId s = 0; s < store.state_vec_.size(); ++s) {
359       State *state = nullptr;
360       const auto *store_state = store.state_vec_[s];
361       if (store_state) {
362         state = new (&state_alloc_) State(*store_state, arc_alloc_);
363         if (cache_gc_) state_list_.push_back(s);
364       }
365       state_vec_.push_back(state);
366     }
367   }
368
369   bool cache_gc_;                               // Supports iteration when true.
370   std::vector<State *> state_vec_;              // Vector of states (or null).
371   StateList state_list_;                        // List of states.
372   typename StateList::iterator iter_;           // State list iterator.
373   typename State::StateAllocator state_alloc_;  // For state allocation.
374   typename State::ArcAllocator arc_alloc_;      // For arc allocation.
375 };
376
377 // This class uses a hash map from state IDs to pointers to cached states.
378 template <class S>
379 class HashCacheStore {
380  public:
381   using State = S;
382   using Arc = typename State::Arc;
383   using StateId = typename Arc::StateId;
384
385   using StateMap =
386       std::unordered_map<StateId, State *, std::hash<StateId>,
387                          std::equal_to<StateId>,
388                          PoolAllocator<std::pair<const StateId, State *>>>;
389
390   // Required constructors/assignment operators.
391   explicit HashCacheStore(const CacheOptions &opts) {
392     Clear();
393     Reset();
394   }
395
396   HashCacheStore(const HashCacheStore<S> &store) {
397     CopyStates(store);
398     Reset();
399   }
400
401   ~HashCacheStore() { Clear(); }
402
403   HashCacheStore<State> &operator=(const HashCacheStore<State> &store) {
404     if (this != &store) {
405       CopyStates(store);
406       Reset();
407     }
408     return *this;
409   }
410
411   // Returns nullptr if state is not stored.
412   const State *GetState(StateId s) const {
413     const auto it = state_map_.find(s);
414     return it != state_map_.end() ? it->second : nullptr;
415   }
416
417   // Creates state if state is not stored.
418   State *GetMutableState(StateId s) {
419     auto *&state = state_map_[s];
420     if (!state) state = new (&state_alloc_) State(arc_alloc_);
421     return state;
422   }
423
424   // Similar to State::AddArc() but updates cache store book-keeping.
425   void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }
426
427   // Similar to State::SetArcs() but updates internal cache size; call only
428   // once.
429   void SetArcs(State *state) { state->SetArcs(); }
430
431   // Deletes all arcs.
432   void DeleteArcs(State *state) { state->DeleteArcs(); }
433
434   // Deletes some arcs.
435   void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }
436
437   // Deletes all cached states.
438   void Clear() {
439     for (auto it = state_map_.begin(); it != state_map_.end(); ++it) {
440       State::Destroy(it->second, &state_alloc_);
441     }
442     state_map_.clear();
443   }
444
445   // Iterates over cached states (in an arbitrary order).
446   bool Done() const { return iter_ == state_map_.end(); }
447
448   StateId Value() const { return iter_->first; }
449
450   void Next() { ++iter_; }
451
452   void Reset() { iter_ = state_map_.begin(); }
453
454   // Deletes current state and advances to next.
455   void Delete() {
456     State::Destroy(iter_->second, &state_alloc_);
457     state_map_.erase(iter_++);
458   }
459
460  private:
461   void CopyStates(const HashCacheStore<State> &store) {
462     Clear();
463     for (auto it = store.state_map_.begin(); it != store.state_map_.end();
464          ++it) {
465       state_map_[it->first] =
466           new (&state_alloc_) State(*it->second, arc_alloc_);
467     }
468   }
469
470   StateMap state_map_;                          // Map from state ID to state.
471   typename StateMap::iterator iter_;            // State map iterator.
472   typename State::StateAllocator state_alloc_;  // For state allocation.
473   typename State::ArcAllocator arc_alloc_;      // For arc allocation.
474 };
475
476 // Garbage-colllection cache stores.
477
478 // This class implements a simple garbage collection scheme when
479 // 'opts.gc_limit = 0'. In particular, the first cached state is reused for each
480 // new state so long as the reference count is zero on the to-be-reused state.
481 // Otherwise, the full underlying store is used. The caller can increment the
482 // reference count to inhibit the GC of in-use states (e.g., in an ArcIterator).
483 //
484 // The typical use case for this optimization is when a single pass over a
485 // cached
486 // FST is performed with only one-state expanded at a time.
487 template <class CacheStore>
488 class FirstCacheStore {
489  public:
490   using State = typename CacheStore::State;
491   using Arc = typename State::Arc;
492   using StateId = typename Arc::StateId;
493
494   // Required constructors/assignment operators.
495   explicit FirstCacheStore(const CacheOptions &opts)
496       : store_(opts),
497         cache_gc_(opts.gc_limit == 0),  // opts.gc ignored historically.
498         cache_first_state_id_(kNoStateId),
499         cache_first_state_(nullptr) {}
500
501   FirstCacheStore(const FirstCacheStore<CacheStore> &store)
502       : store_(store.store_),
503         cache_gc_(store.cache_gc_),
504         cache_first_state_id_(store.cache_first_state_id_),
505         cache_first_state_(store.cache_first_state_id_ != kNoStateId
506                                ? store_.GetMutableState(0)
507                                : nullptr) {}
508
509   FirstCacheStore<CacheStore> &operator=(
510       const FirstCacheStore<CacheStore> &store) {
511     if (this != &store) {
512       store_ = store.store_;
513       cache_gc_ = store.cache_gc_;
514       cache_first_state_id_ = store.cache_first_state_id_;
515       cache_first_state_ = store.cache_first_state_id_ != kNoStateId
516                                ? store_.GetMutableState(0)
517                                : nullptr;
518     }
519     return *this;
520   }
521
522   // Returns nullptr if state is not stored.
523   const State *GetState(StateId s) const {
524     // store_ state 0 may hold first cached state; the rest are shifted by 1.
525     return s == cache_first_state_id_ ? cache_first_state_
526                                       : store_.GetState(s + 1);
527   }
528
529   // Creates state if state is not stored.
530   State *GetMutableState(StateId s) {
531     // store_ state 0 used to hold first cached state; the rest are shifted by
532     // 1.
533     if (cache_first_state_id_ == s) {
534       return cache_first_state_;  // Request for first cached state.
535     }
536     if (cache_gc_) {
537       if (cache_first_state_id_ == kNoStateId) {
538         cache_first_state_id_ = s;  // Sets first cached state.
539         cache_first_state_ = store_.GetMutableState(0);
540         cache_first_state_->SetFlags(kCacheInit, kCacheInit);
541         cache_first_state_->ReserveArcs(2 * kAllocSize);
542         return cache_first_state_;
543       } else if (cache_first_state_->RefCount() == 0) {
544         cache_first_state_id_ = s;  // Updates first cached state.
545         cache_first_state_->Reset();
546         cache_first_state_->SetFlags(kCacheInit, kCacheInit);
547         return cache_first_state_;
548       } else {  // Keeps first cached state.
549         cache_first_state_->SetFlags(0, kCacheInit);  // Clears initialized bit.
550         cache_gc_ = false;                            // Disables GC.
551       }
552     }
553     auto *state = store_.GetMutableState(s + 1);
554     return state;
555   }
556
557   // Similar to State::AddArc() but updates cache store book-keeping.
558   void AddArc(State *state, const Arc &arc) { store_.AddArc(state, arc); }
559
560   // Similar to State::SetArcs() but updates internal cache size; call only
561   // once.
562   void SetArcs(State *state) { store_.SetArcs(state); }
563
564   // Deletes all arcs
565   void DeleteArcs(State *state) { store_.DeleteArcs(state); }
566
567   // Deletes some arcs
568   void DeleteArcs(State *state, size_t n) { store_.DeleteArcs(state, n); }
569
570   // Deletes all cached states
571   void Clear() {
572     store_.Clear();
573     cache_first_state_id_ = kNoStateId;
574     cache_first_state_ = nullptr;
575   }
576
577   // Iterates over cached states (in an arbitrary order). Only needed if GC is
578   // enabled.
579   bool Done() const { return store_.Done(); }
580
581   StateId Value() const {
582     // store_ state 0 may hold first cached state; rest shifted + 1.
583     const auto s = store_.Value();
584     return s ? s - 1 : cache_first_state_id_;
585   }
586
587   void Next() { store_.Next(); }
588
589   void Reset() { store_.Reset(); }
590
591   // Deletes current state and advances to next.
592   void Delete() {
593     if (Value() == cache_first_state_id_) {
594       cache_first_state_id_ = kNoStateId;
595       cache_first_state_ = nullptr;
596     }
597     store_.Delete();
598   }
599
600  private:
601   CacheStore store_;              // Underlying store.
602   bool cache_gc_;                 // GC enabled.
603   StateId cache_first_state_id_;  // First cached state ID.
604   State *cache_first_state_;      // First cached state.
605 };
606
607 // This class implements mark-sweep garbage collection on an underlying cache
608 // store. If GC is enabled, garbage collection of states is performed in a
609 // rough approximation of LRU order once when 'gc_limit' bytes is reached. The
610 // caller can increment the reference count to inhibit the GC of in-use state
611 // (e.g., in an ArcIterator). With GC enabled, the 'gc_limit' parameter allows
612 // the caller to trade-off time vs. space.
613 template <class CacheStore>
614 class GCCacheStore {
615  public:
616   using State = typename CacheStore::State;
617   using Arc = typename State::Arc;
618   using StateId = typename Arc::StateId;
619
620   // Required constructors/assignment operators.
621   explicit GCCacheStore(const CacheOptions &opts)
622       : store_(opts),
623         cache_gc_request_(opts.gc),
624         cache_limit_(opts.gc_limit > kMinCacheLimit ? opts.gc_limit
625                                                     : kMinCacheLimit),
626         cache_gc_(false),
627         cache_size_(0) {}
628
629   // Returns 0 if state is not stored.
630   const State *GetState(StateId s) const { return store_.GetState(s); }
631
632   // Creates state if state is not stored
633   State *GetMutableState(StateId s) {
634     auto *state = store_.GetMutableState(s);
635     if (cache_gc_request_ && !(state->Flags() & kCacheInit)) {
636       state->SetFlags(kCacheInit, kCacheInit);
637       cache_size_ += sizeof(State) + state->NumArcs() * sizeof(Arc);
638       // GC is enabled once an uninited state (from underlying store) is seen.
639       cache_gc_ = true;
640       if (cache_size_ > cache_limit_) GC(state, false);
641     }
642     return state;
643   }
644
645   // Similar to State::AddArc() but updates cache store book-keeping.
646   void AddArc(State *state, const Arc &arc) {
647     store_.AddArc(state, arc);
648     if (cache_gc_ && (state->Flags() & kCacheInit)) {
649       cache_size_ += sizeof(Arc);
650       if (cache_size_ > cache_limit_) GC(state, false);
651     }
652   }
653
654   // Similar to State::SetArcs() but updates internal cache size; call only
655   // once.
656   void SetArcs(State *state) {
657     store_.SetArcs(state);
658     if (cache_gc_ && (state->Flags() & kCacheInit)) {
659       cache_size_ += state->NumArcs() * sizeof(Arc);
660       if (cache_size_ > cache_limit_) GC(state, false);
661     }
662   }
663
664   // Deletes all arcs.
665   void DeleteArcs(State *state) {
666     if (cache_gc_ && (state->Flags() & kCacheInit)) {
667       cache_size_ -= state->NumArcs() * sizeof(Arc);
668     }
669     store_.DeleteArcs(state);
670   }
671
672   // Deletes some arcs.
673   void DeleteArcs(State *state, size_t n) {
674     if (cache_gc_ && (state->Flags() & kCacheInit)) {
675       cache_size_ -= n * sizeof(Arc);
676     }
677     store_.DeleteArcs(state, n);
678   }
679
680   // Deletes all cached states.
681   void Clear() {
682     store_.Clear();
683     cache_size_ = 0;
684   }
685
686   // Iterates over cached states (in an arbitrary order); only needed if GC is
687   // enabled.
688   bool Done() const { return store_.Done(); }
689
690   StateId Value() const { return store_.Value(); }
691
692   void Next() { store_.Next(); }
693
694   void Reset() { store_.Reset(); }
695
696   // Deletes current state and advances to next.
697   void Delete() {
698     if (cache_gc_) {
699       const auto *state = store_.GetState(Value());
700       if (state->Flags() & kCacheInit) {
701         cache_size_ -= sizeof(State) + state->NumArcs() * sizeof(Arc);
702       }
703     }
704     store_.Delete();
705   }
706
707   // Removes from the cache store (not referenced-counted and not the current)
708   // states that have not been accessed since the last GC until at most
709   // cache_fraction * cache_limit_ bytes are cached. If that fails to free
710   // enough, attempts to uncaching recently visited states as well. If still
711   // unable to free enough memory, then widens cache_limit_.
712   void GC(const State *current, bool free_recent, float cache_fraction = 0.666);
713
714   // Returns the current cache size in bytes or 0 if GC is disabled.
715   size_t CacheSize() const { return cache_size_; }
716
717   // Returns the cache limit in bytes.
718   size_t CacheLimit() const { return cache_limit_; }
719
720  private:
721   static constexpr size_t kMinCacheLimit = 8096;  // Minimum cache limit.
722
723   CacheStore store_;       // Underlying store.
724   bool cache_gc_request_;  // GC requested but possibly not yet enabled.
725   size_t cache_limit_;     // Number of bytes allowed before GC.
726   bool cache_gc_;          // GC enabled
727   size_t cache_size_;      // Number of bytes cached.
728 };
729
730 template <class CacheStore>
731 void GCCacheStore<CacheStore>::GC(const State *current, bool free_recent,
732                                   float cache_fraction) {
733   if (!cache_gc_) return;
734   VLOG(2) << "GCCacheStore: Enter GC: object = "
735           << "(" << this << "), free recently cached = " << free_recent
736           << ", cache size = " << cache_size_
737           << ", cache frac = " << cache_fraction
738           << ", cache limit = " << cache_limit_ << "\n";
739   size_t cache_target = cache_fraction * cache_limit_;
740   store_.Reset();
741   while (!store_.Done()) {
742     auto *state = store_.GetMutableState(store_.Value());
743     if (cache_size_ > cache_target && state->RefCount() == 0 &&
744         (free_recent || !(state->Flags() & kCacheRecent)) && state != current) {
745       if (state->Flags() & kCacheInit) {
746         size_t size = sizeof(State) + state->NumArcs() * sizeof(Arc);
747         if (size < cache_size_) {
748           cache_size_ -= size;
749         }
750       }
751       store_.Delete();
752     } else {
753       state->SetFlags(0, kCacheRecent);
754       store_.Next();
755     }
756   }
757   if (!free_recent && cache_size_ > cache_target) {  // Recurses on recent.
758     GC(current, true, cache_fraction);
759   } else if (cache_target > 0) {  // Widens cache limit.
760     while (cache_size_ > cache_target) {
761       cache_limit_ *= 2;
762       cache_target *= 2;
763     }
764   } else if (cache_size_ > 0) {
765     FSTERROR() << "GCCacheStore:GC: Unable to free all cached states";
766   }
767   VLOG(2) << "GCCacheStore: Exit GC: object = "
768           << "(" << this << "), free recently cached = " << free_recent
769           << ", cache size = " << cache_size_
770           << ", cache frac = " << cache_fraction
771           << ", cache limit = " << cache_limit_ << "\n";
772 }
773
774 template <class CacheStore>
775 constexpr size_t GCCacheStore<CacheStore>::kMinCacheLimit;
776
777 // This class is the default cache state and store used by CacheBaseImpl.
778 // It uses VectorCacheStore for storage decorated by FirstCacheStore
779 // and GCCacheStore to do (optional) garbage collection.
780 template <class Arc>
781 class DefaultCacheStore
782     : public GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>> {
783  public:
784   explicit DefaultCacheStore(const CacheOptions &opts)
785       : GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>>(opts) {
786   }
787 };
788
789 namespace internal {
790
791 // This class is used to cache FST elements stored in states of type State
792 // (see CacheState) with the flags used to indicate what has been cached. Use
793 // HasStart(), HasFinal(), and HasArcs() to determine if cached and SetStart(),
794 // SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note that you
795 // must set the final weight even if the state is non-final to mark it as
796 // cached. The state storage method and any garbage collection policy are
797 // determined by the cache store. If the store is passed in with the options,
798 // CacheBaseImpl takes ownership.
799 template <class State,
800           class CacheStore = DefaultCacheStore<typename State::Arc>>
801 class CacheBaseImpl : public FstImpl<typename State::Arc> {
802  public:
803   using Arc = typename State::Arc;
804   using StateId = typename Arc::StateId;
805   using Weight = typename Arc::Weight;
806
807   using Store = CacheStore;
808
809   using FstImpl<Arc>::Type;
810   using FstImpl<Arc>::Properties;
811
812   explicit CacheBaseImpl(const CacheOptions &opts = CacheOptions())
813       : has_start_(false),
814         cache_start_(kNoStateId),
815         nknown_states_(0),
816         min_unexpanded_state_id_(0),
817         max_expanded_state_id_(-1),
818         cache_gc_(opts.gc),
819         cache_limit_(opts.gc_limit),
820         cache_store_(new CacheStore(opts)),
821         new_cache_store_(true),
822         own_cache_store_(true) {}
823
824   explicit CacheBaseImpl(const CacheImplOptions<CacheStore> &opts)
825       : has_start_(false),
826         cache_start_(kNoStateId),
827         nknown_states_(0),
828         min_unexpanded_state_id_(0),
829         max_expanded_state_id_(-1),
830         cache_gc_(opts.gc),
831         cache_limit_(opts.gc_limit),
832         cache_store_(opts.store ? opts.store : new CacheStore(CacheOptions(
833                                                    opts.gc, opts.gc_limit))),
834         new_cache_store_(!opts.store),
835         own_cache_store_(opts.store ? opts.own_store : true) {}
836
837   // Preserve gc parameters. If preserve_cache is true, also preserves
838   // cache data.
839   CacheBaseImpl(const CacheBaseImpl<State, CacheStore> &impl,
840                 bool preserve_cache = false)
841       : FstImpl<Arc>(),
842         has_start_(false),
843         cache_start_(kNoStateId),
844         nknown_states_(0),
845         min_unexpanded_state_id_(0),
846         max_expanded_state_id_(-1),
847         cache_gc_(impl.cache_gc_),
848         cache_limit_(impl.cache_limit_),
849         cache_store_(new CacheStore(CacheOptions(cache_gc_, cache_limit_))),
850         new_cache_store_(impl.new_cache_store_ || !preserve_cache),
851         own_cache_store_(true) {
852     if (preserve_cache) {
853       *cache_store_ = *impl.cache_store_;
854       has_start_ = impl.has_start_;
855       cache_start_ = impl.cache_start_;
856       nknown_states_ = impl.nknown_states_;
857       expanded_states_ = impl.expanded_states_;
858       min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
859       max_expanded_state_id_ = impl.max_expanded_state_id_;
860     }
861   }
862
863   ~CacheBaseImpl() override { if (own_cache_store_) delete cache_store_; }
864
865   void SetStart(StateId s) {
866     cache_start_ = s;
867     has_start_ = true;
868     if (s >= nknown_states_) nknown_states_ = s + 1;
869   }
870
871   void SetFinal(StateId s, Weight weight) {
872     auto *state = cache_store_->GetMutableState(s);
873     state->SetFinal(std::move(weight));
874     static constexpr auto flags = kCacheFinal | kCacheRecent;
875     state->SetFlags(flags, flags);
876   }
877
878 // Disabled to ensure PushArc not AddArc is used in existing code
879 // TODO(sorenj): re-enable for backing store
880 #if 0
881   // AddArc adds a single arc to a state and does incremental cache
882   // book-keeping. For efficiency, prefer PushArc and SetArcs below
883   // when possible.
884   void AddArc(StateId s, const Arc &arc) {
885     auto *state = cache_store_->GetMutableState(s);
886     cache_store_->AddArc(state, arc);
887     if (arc.nextstate >= nknown_states_)
888       nknown_states_ = arc.nextstate + 1;
889     SetExpandedState(s);
890     static constexpr auto flags = kCacheArcs | kCacheRecent;
891     state->SetFlags(flags, flags);
892   }
893 #endif
894
895   // Adds a single arc to a state but delays cache book-keeping. SetArcs must
896   // be called when all PushArc calls at a state are complete. Do not mix with
897   // calls to AddArc.
898   void PushArc(StateId s, const Arc &arc) {
899     auto *state = cache_store_->GetMutableState(s);
900     state->PushArc(arc);
901   }
902
903   // Marks arcs of a state as cached and does cache book-keeping after all
904   // calls to PushArc have been completed. Do not mix with calls to AddArc.
905   void SetArcs(StateId s) {
906     auto *state = cache_store_->GetMutableState(s);
907     cache_store_->SetArcs(state);
908     const auto narcs = state->NumArcs();
909     for (size_t a = 0; a < narcs; ++a) {
910       const auto &arc = state->GetArc(a);
911       if (arc.nextstate >= nknown_states_) nknown_states_ = arc.nextstate + 1;
912     }
913     SetExpandedState(s);
914     static constexpr auto flags = kCacheArcs | kCacheRecent;
915     state->SetFlags(flags, flags);
916   }
917
918   void ReserveArcs(StateId s, size_t n) {
919     auto *state = cache_store_->GetMutableState(s);
920     state->ReserveArcs(n);
921   }
922
923   void DeleteArcs(StateId s) {
924     auto *state = cache_store_->GetMutableState(s);
925     cache_store_->DeleteArcs(state);
926   }
927
928   void DeleteArcs(StateId s, size_t n) {
929     auto *state = cache_store_->GetMutableState(s);
930     cache_store_->DeleteArcs(state, n);
931   }
932
933   void Clear() {
934     nknown_states_ = 0;
935     min_unexpanded_state_id_ = 0;
936     max_expanded_state_id_ = -1;
937     has_start_ = false;
938     cache_start_ = kNoStateId;
939     cache_store_->Clear();
940   }
941
942   // Is the start state cached?
943   bool HasStart() const {
944     if (!has_start_ && Properties(kError)) has_start_ = true;
945     return has_start_;
946   }
947
948   // Is the final weight of the state cached?
949   bool HasFinal(StateId s) const {
950     const auto *state = cache_store_->GetState(s);
951     if (state && state->Flags() & kCacheFinal) {
952       state->SetFlags(kCacheRecent, kCacheRecent);
953       return true;
954     } else {
955       return false;
956     }
957   }
958
959   // Are arcs of the state cached?
960   bool HasArcs(StateId s) const {
961     const auto *state = cache_store_->GetState(s);
962     if (state && state->Flags() & kCacheArcs) {
963       state->SetFlags(kCacheRecent, kCacheRecent);
964       return true;
965     } else {
966       return false;
967     }
968   }
969
970   StateId Start() const { return cache_start_; }
971
972   Weight Final(StateId s) const {
973     const auto *state = cache_store_->GetState(s);
974     return state->Final();
975   }
976
977   size_t NumArcs(StateId s) const {
978     const auto *state = cache_store_->GetState(s);
979     return state->NumArcs();
980   }
981
982   size_t NumInputEpsilons(StateId s) const {
983     const auto *state = cache_store_->GetState(s);
984     return state->NumInputEpsilons();
985   }
986
987   size_t NumOutputEpsilons(StateId s) const {
988     const auto *state = cache_store_->GetState(s);
989     return state->NumOutputEpsilons();
990   }
991
992   // Provides information needed for generic arc iterator.
993   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
994     const auto *state = cache_store_->GetState(s);
995     data->base = nullptr;
996     data->narcs = state->NumArcs();
997     data->arcs = state->Arcs();
998     data->ref_count = state->MutableRefCount();
999     state->IncrRefCount();
1000   }
1001
1002   // Number of known states.
1003   StateId NumKnownStates() const { return nknown_states_; }
1004
1005   // Updates number of known states, taking into account the passed state ID.
1006   void UpdateNumKnownStates(StateId s) {
1007     if (s >= nknown_states_) nknown_states_ = s + 1;
1008   }
1009
1010   // Finds the mininum never-expanded state ID.
1011   StateId MinUnexpandedState() const {
1012     while (min_unexpanded_state_id_ <= max_expanded_state_id_ &&
1013            ExpandedState(min_unexpanded_state_id_)) {
1014       ++min_unexpanded_state_id_;
1015     }
1016     return min_unexpanded_state_id_;
1017   }
1018
1019   // Returns maximum ever-expanded state ID.
1020   StateId MaxExpandedState() const { return max_expanded_state_id_; }
1021
1022   void SetExpandedState(StateId s) {
1023     if (s > max_expanded_state_id_) max_expanded_state_id_ = s;
1024     if (s < min_unexpanded_state_id_) return;
1025     if (s == min_unexpanded_state_id_) ++min_unexpanded_state_id_;
1026     if (cache_gc_ || cache_limit_ == 0) {
1027       while (expanded_states_.size() <= s) expanded_states_.push_back(false);
1028       expanded_states_[s] = true;
1029     }
1030   }
1031
1032   bool ExpandedState(StateId s) const {
1033     if (cache_gc_ || cache_limit_ == 0) {
1034       return expanded_states_[s];
1035     } else if (new_cache_store_) {
1036       return cache_store_->GetState(s) != nullptr;
1037     } else {
1038       // If the cache was not created by this class, then the cached state needs
1039       // to be inspected to update nknown_states_.
1040       return false;
1041     }
1042   }
1043
1044   const CacheStore *GetCacheStore() const { return cache_store_; }
1045
1046   CacheStore *GetCacheStore() { return cache_store_; }
1047
1048   // Caching on/off switch, limit and size accessors.
1049
1050   bool GetCacheGc() const { return cache_gc_; }
1051
1052   size_t GetCacheLimit() const { return cache_limit_; }
1053
1054  private:
1055   mutable bool has_start_;                   // Is the start state cached?
1056   StateId cache_start_;                      // ID of start state.
1057   StateId nknown_states_;                    // Number of known states.
1058   std::vector<bool> expanded_states_;        // States that have been expanded.
1059   mutable StateId min_unexpanded_state_id_;  // Minimum never-expanded state ID
1060   mutable StateId max_expanded_state_id_;    // Maximum ever-expanded state ID
1061   bool cache_gc_;                            // GC enabled.
1062   size_t cache_limit_;       // Number of bytes allowed before GC.
1063   CacheStore *cache_store_;  // The store of cached states.
1064   bool new_cache_store_;     // Was the store was created by class?
1065   bool own_cache_store_;     // Is the store owned by class?
1066
1067   CacheBaseImpl &operator=(const CacheBaseImpl &impl) = delete;
1068 };
1069
1070 // A CacheBaseImpl with the default cache state type.
1071 template <class Arc>
1072 class CacheImpl : public CacheBaseImpl<CacheState<Arc>> {
1073  public:
1074   using State = CacheState<Arc>;
1075
1076   CacheImpl() {}
1077
1078   explicit CacheImpl(const CacheOptions &opts)
1079       : CacheBaseImpl<CacheState<Arc>>(opts) {}
1080
1081   CacheImpl(const CacheImpl<Arc> &impl, bool preserve_cache = false)
1082       : CacheBaseImpl<State>(impl, preserve_cache) {}
1083
1084  private:
1085   CacheImpl &operator=(const CacheImpl &impl) = delete;
1086 };
1087
1088 }  // namespace internal
1089
1090 // Use this to make a state iterator for a CacheBaseImpl-derived FST, which must
1091 // have Arc and Store types defined. Note this iterator only returns those
1092 // states reachable from the initial state, so consider implementing a
1093 // class-specific one.
1094 //
1095 // This class may be derived from.
1096 template <class FST>
1097 class CacheStateIterator : public StateIteratorBase<typename FST::Arc> {
1098  public:
1099   using Arc = typename FST::Arc;
1100   using StateId = typename Arc::StateId;
1101   using Weight = typename Arc::Weight;
1102
1103   using Store = typename FST::Store;
1104   using State = typename Store::State;
1105   using Impl = internal::CacheBaseImpl<State, Store>;
1106
1107   CacheStateIterator(const FST &fst, Impl *impl)
1108       : fst_(fst), impl_(impl), s_(0) {
1109     fst_.Start();  // Forces start state.
1110   }
1111
1112   bool Done() const final {
1113     if (s_ < impl_->NumKnownStates()) return false;
1114     for (StateId u = impl_->MinUnexpandedState(); u < impl_->NumKnownStates();
1115          u = impl_->MinUnexpandedState()) {
1116       // Forces state expansion.
1117       ArcIterator<FST> aiter(fst_, u);
1118       aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
1119       for (; !aiter.Done(); aiter.Next()) {
1120         impl_->UpdateNumKnownStates(aiter.Value().nextstate);
1121       }
1122       impl_->SetExpandedState(u);
1123       if (s_ < impl_->NumKnownStates()) return false;
1124     }
1125     return true;
1126   }
1127
1128   StateId Value() const final { return s_; }
1129
1130   void Next() final { ++s_; }
1131
1132   void Reset() final { s_ = 0; }
1133
1134  private:
1135   const FST &fst_;
1136   Impl *impl_;
1137   StateId s_;
1138 };
1139
1140 // Used to make an arc iterator for a CacheBaseImpl-derived FST, which must
1141 // have Arc and State types defined.
1142 template <class FST>
1143 class CacheArcIterator {
1144  public:
1145   using Arc = typename FST::Arc;
1146   using StateId = typename Arc::StateId;
1147   using Weight = typename Arc::Weight;
1148
1149   using Store = typename FST::Store;
1150   using State = typename Store::State;
1151   using Impl = internal::CacheBaseImpl<State, Store>;
1152
1153   CacheArcIterator(Impl *impl, StateId s) : i_(0) {
1154     state_ = impl->GetCacheStore()->GetMutableState(s);
1155     state_->IncrRefCount();
1156   }
1157
1158   ~CacheArcIterator() { state_->DecrRefCount(); }
1159
1160   bool Done() const { return i_ >= state_->NumArcs(); }
1161
1162   const Arc &Value() const { return state_->GetArc(i_); }
1163
1164   void Next() { ++i_; }
1165
1166   size_t Position() const { return i_; }
1167
1168   void Reset() { i_ = 0; }
1169
1170   void Seek(size_t a) { i_ = a; }
1171
1172   constexpr uint32 Flags() const { return kArcValueFlags; }
1173
1174   void SetFlags(uint32 flags, uint32 mask) {}
1175
1176  private:
1177   const State *state_;
1178   size_t i_;
1179
1180   CacheArcIterator(const CacheArcIterator &) = delete;
1181   CacheArcIterator &operator=(const CacheArcIterator &) = delete;
1182 };
1183
1184 // Use this to make a mutable arc iterator for a CacheBaseImpl-derived FST,
1185 // which must have types Arc and Store defined.
1186 template <class FST>
1187 class CacheMutableArcIterator
1188     : public MutableArcIteratorBase<typename FST::Arc> {
1189  public:
1190   using Arc = typename FST::Arc;
1191   using StateId = typename Arc::StateId;
1192   using Weight = typename Arc::Weight;
1193
1194   using Store = typename FST::Store;
1195   using State = typename Store::State;
1196   using Impl = internal::CacheBaseImpl<State, Store>;
1197
1198   // User must call MutateCheck() in the constructor.
1199   CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
1200     state_ = impl_->GetCacheStore()->GetMutableState(s_);
1201     state_->IncrRefCount();
1202   }
1203
1204   ~CacheMutableArcIterator() override { state_->DecrRefCount(); }
1205
1206   bool Done() const final { return i_ >= state_->NumArcs(); }
1207
1208   const Arc &Value() const final { return state_->GetArc(i_); }
1209
1210   void Next() final { ++i_; }
1211
1212   size_t Position() const final { return i_; }
1213
1214   void Reset() final { i_ = 0; }
1215
1216   void Seek(size_t a) final { i_ = a; }
1217
1218   void SetValue(const Arc &arc) final { state_->SetArc(arc, i_); }
1219
1220   uint32 Flags() const final { return kArcValueFlags; }
1221
1222   void SetFlags(uint32, uint32) final {}
1223
1224  private:
1225   size_t i_;
1226   StateId s_;
1227   Impl *impl_;
1228   State *state_;
1229
1230   CacheMutableArcIterator(const CacheMutableArcIterator &) = delete;
1231   CacheMutableArcIterator &operator=(const CacheMutableArcIterator &) = delete;
1232 };
1233
1234 // Wrap existing CacheStore implementation to use with ExpanderFst.
1235 template <class CacheStore>
1236 class ExpanderCacheStore {
1237  public:
1238   using State = typename CacheStore::State;
1239   using Arc = typename CacheStore::Arc;
1240   using StateId = typename Arc::StateId;
1241   using Weight = typename Arc::Weight;
1242
1243   explicit ExpanderCacheStore(const CacheOptions &opts = CacheOptions())
1244       : store_(opts) {}
1245
1246   template <class Expander>
1247   State *FindOrExpand(Expander &expander, StateId s) {  // NOLINT
1248     auto *state = store_.GetMutableState(s);
1249     if (state->Flags()) {
1250       state->SetFlags(kCacheRecent, kCacheRecent);
1251     } else {
1252       StateBuilder builder(state);
1253       expander.Expand(s, &builder);
1254       state->SetFlags(kCacheFlags, kCacheFlags);
1255       store_.SetArcs(state);
1256     }
1257     return state;
1258   }
1259
1260  private:
1261   CacheStore store_;
1262
1263   struct StateBuilder {
1264     State *state;
1265
1266     explicit StateBuilder(State *state_) : state(state_) {}
1267
1268     void AddArc(const Arc &arc) { state->PushArc(arc); }
1269
1270     void SetFinal(Weight weight) { state->SetFinal(std::move(weight)); }
1271   };
1272 };
1273
1274 }  // namespace fst
1275
1276 #endif  // FST_LIB_CACHE_H_