8c528c06a79a95dfeb9e17da284a4e44578e4e8e
[platform/upstream/openfst.git] / src / include / fst / cache.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // An FST implementation that caches FST elements of a delayed computation.
5
6 #ifndef FST_LIB_CACHE_H_
7 #define FST_LIB_CACHE_H_
8
9 #include <functional>
10 #include <unordered_map>
11 using std::unordered_map;
12 using std::unordered_multimap;
13 #include <list>
14 #include <vector>
15
16 #include <fst/log.h>
17
18 #include <fst/vector-fst.h>
19
20
21 DECLARE_bool(fst_default_cache_gc);
22 DECLARE_int64(fst_default_cache_gc_limit);
23
24 namespace fst {
25
26 // Options for controlling caching behavior; higher level than CacheImplOptions.
27 struct CacheOptions {
28   bool gc;          // Enables GC.
29   size_t gc_limit;  // Number of bytes allowed before GC.
30
31   explicit CacheOptions(bool gc = FLAGS_fst_default_cache_gc,
32                         size_t gc_limit = FLAGS_fst_default_cache_gc_limit)
33       : gc(gc), gc_limit(gc_limit) {}
34 };
35
36 // Options for controlling caching behavior, at a lower level than
37 // CacheOptions; templated on the cache store and allows passing the store.
38 template <class CacheStore>
39 struct CacheImplOptions {
40   bool gc;            // Enables GC.
41   size_t gc_limit;    // Number of bytes allowed before GC.
42   CacheStore *store;  // Cache store.
43   bool own_store;     // Should CacheImpl takes ownership of the store?
44
45   explicit CacheImplOptions(bool gc = FLAGS_fst_default_cache_gc,
46                             size_t gc_limit = FLAGS_fst_default_cache_gc_limit,
47                             CacheStore *store = nullptr)
48       : gc(gc), gc_limit(gc_limit), store(store), own_store(true) {}
49
50   explicit CacheImplOptions(const CacheOptions &opts)
51       : gc(opts.gc), gc_limit(opts.gc_limit), store(nullptr), own_store(true) {}
52 };
53
54 // Cache flags.
55 constexpr uint32 kCacheFinal = 0x0001;   // Final weight has been cached.
56 constexpr uint32 kCacheArcs = 0x0002;    // Arcs have been cached.
57 constexpr uint32 kCacheInit = 0x0004;    // Initialized by GC.
58 constexpr uint32 kCacheRecent = 0x0008;  // Visited since GC.
59 constexpr uint32 kCacheFlags =
60     kCacheFinal | kCacheArcs | kCacheInit | kCacheRecent;
61
62 // Cache state, with arcs stored in a per-state std::vector.
63 template <class A, class M = PoolAllocator<A>>
64 class CacheState {
65  public:
66   using Arc = A;
67   using Label = typename Arc::Label;
68   using StateId = typename Arc::StateId;
69   using Weight = typename Arc::Weight;
70
71   using ArcAllocator = M;
72   using StateAllocator =
73       typename ArcAllocator::template rebind<CacheState<A, M>>::other;
74
75   // Provides STL allocator for arcs.
76   explicit CacheState(const ArcAllocator &alloc)
77       : final_(Weight::Zero()),
78         niepsilons_(0),
79         noepsilons_(0),
80         arcs_(alloc),
81         flags_(0),
82         ref_count_(0) {}
83
84   CacheState(const CacheState<A> &state, const ArcAllocator &alloc)
85       : final_(state.Final()),
86         niepsilons_(state.NumInputEpsilons()),
87         noepsilons_(state.NumOutputEpsilons()),
88         arcs_(state.arcs_.begin(), state.arcs_.end(), alloc),
89         flags_(state.Flags()),
90         ref_count_(0) {}
91
92   void Reset() {
93     final_ = Weight::Zero();
94     niepsilons_ = 0;
95     noepsilons_ = 0;
96     ref_count_ = 0;
97     flags_ = 0;
98     arcs_.clear();
99   }
100
101   Weight Final() const { return final_; }
102
103   size_t NumInputEpsilons() const { return niepsilons_; }
104
105   size_t NumOutputEpsilons() const { return noepsilons_; }
106
107   size_t NumArcs() const { return arcs_.size(); }
108
109   const Arc &GetArc(size_t n) const { return arcs_[n]; }
110
111   // Used by the ArcIterator<Fst<Arc>> efficient implementation.
112   const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; }
113
114   // Accesses flags; used by the caller.
115   uint32 Flags() const { return flags_; }
116
117   // Accesses ref count; used by the caller.
118   int RefCount() const { return ref_count_; }
119
120   void SetFinal(Weight weight) { final_ = std::move(weight); }
121
122   void ReserveArcs(size_t n) { arcs_.reserve(n); }
123
124   // Adds one arc at a time with all needed book-keeping; use PushArc and
125   // SetArcs for a more efficient alternative.
126   void AddArc(const Arc &arc) {
127     arcs_.push_back(arc);
128     if (arc.ilabel == 0) ++niepsilons_;
129     if (arc.olabel == 0) ++noepsilons_;
130   }
131
132   // Adds one arc at a time with delayed book-keeping; finalize with SetArcs().
133   void PushArc(const Arc &arc) { arcs_.push_back(arc); }
134
135   // Finalizes arcs book-keeping; call only once.
136   void SetArcs() {
137     for (const auto &arc : arcs_) {
138       if (arc.ilabel == 0) ++niepsilons_;
139       if (arc.olabel == 0) ++noepsilons_;
140     }
141   }
142
143   // Modifies nth arc.
144   void SetArc(const Arc &arc, size_t n) {
145     if (arcs_[n].ilabel == 0) --niepsilons_;
146     if (arcs_[n].olabel == 0) --noepsilons_;
147     if (arc.ilabel == 0) ++niepsilons_;
148     if (arc.olabel == 0) ++noepsilons_;
149     arcs_[n] = arc;
150   }
151
152   // Deletes all arcs.
153   void DeleteArcs() {
154     niepsilons_ = 0;
155     noepsilons_ = 0;
156     arcs_.clear();
157   }
158
159   void DeleteArcs(size_t n) {
160     for (size_t i = 0; i < n; ++i) {
161       if (arcs_.back().ilabel == 0) --niepsilons_;
162       if (arcs_.back().olabel == 0) --noepsilons_;
163       arcs_.pop_back();
164     }
165   }
166
167   // Sets status flags; used by the caller.
168   void SetFlags(uint32 flags, uint32 mask) const {
169     flags_ &= ~mask;
170     flags_ |= flags;
171   }
172
173   // Mutates reference counts; used by the caller.
174
175   int IncrRefCount() const { return ++ref_count_; }
176
177   int DecrRefCount() const { return --ref_count_; }
178
179   // Used by the ArcIterator<Fst<Arc>> efficient implementation.
180   int *MutableRefCount() const { return &ref_count_; }
181
182   // Used for state class allocation.
183   void *operator new(size_t size, StateAllocator *alloc) {
184     return alloc->allocate(1);
185   }
186
187   // For state destruction and memory freeing.
188   static void Destroy(CacheState<Arc> *state, StateAllocator *alloc) {
189     if (state) {
190       state->~CacheState<Arc>();
191       alloc->deallocate(state, 1);
192     }
193   }
194
195  private:
196   Weight final_;                         // Final weight.
197   size_t niepsilons_;                    // # of input epsilons.
198   size_t noepsilons_;                    // # of output epsilons.
199   std::vector<Arc, ArcAllocator> arcs_;  // Arcs representation.
200   mutable uint32 flags_;
201   mutable int ref_count_;  // If 0, available for GC.
202 };
203
204 // Cache store, allocating and storing states, providing a mapping from state
205 // IDs to cached states, and an iterator over these states. The state template
206 // argument must implement the CacheState interface. The state for a StateId s
207 // is constructed when requested by GetMutableState(s) if it is not yet stored.
208 // Initially, a state has a reference count of zero, but the user may increment
209 // or decrement this to control the time of destruction. In particular, a state
210 // is destroyed when:
211 //
212 // 1. This instance is destroyed, or
213 // 2. Clear() or Delete() is called, or
214 // 3. Possibly (implementation-dependently) when:
215 //    - Garbage collection is enabled (as defined by opts.gc),
216 //    - The cache store size exceeds the limits (as defined by opts.gc_limits),
217 //    - The state's reference count is zero, and
218 //    - The state is not the most recently requested state.
219 //
220 // template <class S>
221 // class CacheStore {
222 //  public:
223 //   using State = S;
224 //   using Arc = typename State::Arc;
225 //   using StateId = typename Arc::StateId;
226 //
227 //   // Required constructors/assignment operators.
228 //   explicit CacheStore(const CacheOptions &opts);
229 //
230 //   // Returns nullptr if state is not stored.
231 //   const State *GetState(StateId s);
232 //
233 //   // Creates state if state is not stored.
234 //   State *GetMutableState(StateId s);
235 //
236 //   // Similar to State::AddArc() but updates cache store book-keeping.
237 //   void AddArc(State *state, const Arc &arc);
238 //
239 //   // Similar to State::SetArcs() but updates cache store book-keeping; call
240 //   // only once.
241 //   void SetArcs(State *state);
242 //
243 //   // Similar to State::DeleteArcs() but updates cache store book-keeping.
244 //
245 //   void DeleteArcs(State *state);
246 //
247 //   void DeleteArcs(State *state, size_t n);
248 //
249 //   // Deletes all cached states.
250 //   void Clear();
251 //
252 //   // Iterates over cached states (in an arbitrary order); only needed if
253 //   // opts.gc is true.
254 //   bool Done() const;      // End of iteration.
255 //   StateId Value() const;  // Current state.
256 //   void Next();            // Advances to next state (when !Done).
257 //   void Reset();           // Returns to initial condition.
258 //   void Delete();          // Deletes current state and advances to next.
259 // };
260
261 // Container cache stores.
262
263 // This class uses a vector of pointers to states to store cached states.
264 template <class S>
265 class VectorCacheStore {
266  public:
267   using State = S;
268   using Arc = typename State::Arc;
269   using StateId = typename Arc::StateId;
270   using StateList = std::list<StateId, PoolAllocator<StateId>>;
271
272   // Required constructors/assignment operators.
273   explicit VectorCacheStore(const CacheOptions &opts) : cache_gc_(opts.gc) {
274     Clear();
275     Reset();
276   }
277
278   VectorCacheStore(const VectorCacheStore<S> &store)
279       : cache_gc_(store.cache_gc_) {
280     CopyStates(store);
281     Reset();
282   }
283
284   ~VectorCacheStore() { Clear(); }
285
286   VectorCacheStore<State> &operator=(const VectorCacheStore<State> &store) {
287     if (this != &store) {
288       CopyStates(store);
289       Reset();
290     }
291     return *this;
292   }
293
294   // Returns nullptr if state is not stored.
295   const State *GetState(StateId s) const {
296     return s < state_vec_.size() ? state_vec_[s] : nullptr;
297   }
298
299   // Creates state if state is not stored.
300   State *GetMutableState(StateId s) {
301     State *state = nullptr;
302     if (s >= state_vec_.size()) {
303       state_vec_.resize(s + 1, nullptr);
304     } else {
305       state = state_vec_[s];
306     }
307     if (!state) {
308       state = new (&state_alloc_) State(arc_alloc_);
309       state_vec_[s] = state;
310       if (cache_gc_) state_list_.push_back(s);
311     }
312     return state;
313   }
314
315   // Similar to State::AddArc() but updates cache store book-keeping
316   void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }
317
318   // Similar to State::SetArcs() but updates cache store book-keeping; call
319   // only once.
320   void SetArcs(State *state) { state->SetArcs(); }
321
322   // Deletes all arcs.
323   void DeleteArcs(State *state) { state->DeleteArcs(); }
324
325   // Deletes some arcs.
326   void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }
327
328   // Deletes all cached states.
329   void Clear() {
330     for (StateId s = 0; s < state_vec_.size(); ++s) {
331       State::Destroy(state_vec_[s], &state_alloc_);
332     }
333     state_vec_.clear();
334     state_list_.clear();
335   }
336
337   // Iterates over cached states (in an arbitrary order); only works if GC is
338   // enabled (o.w. avoiding state_list_ overhead).
339   bool Done() const { return iter_ == state_list_.end(); }
340
341   StateId Value() const { return *iter_; }
342
343   void Next() { ++iter_; }
344
345   void Reset() { iter_ = state_list_.begin(); }
346
347   // Deletes current state and advances to next.
348   void Delete() {
349     State::Destroy(state_vec_[*iter_], &state_alloc_);
350     state_vec_[*iter_] = nullptr;
351     state_list_.erase(iter_++);
352   }
353
354  private:
355   void CopyStates(const VectorCacheStore<State> &store) {
356     Clear();
357     state_vec_.reserve(store.state_vec_.size());
358     for (StateId s = 0; s < store.state_vec_.size(); ++s) {
359       State *state = nullptr;
360       const auto *store_state = store.state_vec_[s];
361       if (store_state) {
362         state = new (&state_alloc_) State(*store_state, arc_alloc_);
363         if (cache_gc_) state_list_.push_back(s);
364       }
365       state_vec_.push_back(state);
366     }
367   }
368
369   bool cache_gc_;                               // Supports iteration when true.
370   std::vector<State *> state_vec_;              // Vector of states (or null).
371   StateList state_list_;                        // List of states.
372   typename StateList::iterator iter_;           // State list iterator.
373   typename State::StateAllocator state_alloc_;  // For state allocation.
374   typename State::ArcAllocator arc_alloc_;      // For arc allocation.
375 };
376
377 // This class uses a hash map from state IDs to pointers to cached states.
378 template <class S>
379 class HashCacheStore {
380  public:
381   using State = S;
382   using Arc = typename State::Arc;
383   using StateId = typename Arc::StateId;
384
385   using StateMap =
386       std::unordered_map<StateId, State *, std::hash<StateId>,
387                          std::equal_to<StateId>,
388                          PoolAllocator<std::pair<const StateId, State *>>>;
389
390   // Required constructors/assignment operators.
391   explicit HashCacheStore(const CacheOptions &opts) {
392     Clear();
393     Reset();
394   }
395
396   HashCacheStore(const HashCacheStore<S> &store) {
397     CopyStates(store);
398     Reset();
399   }
400
401   ~HashCacheStore() { Clear(); }
402
403   HashCacheStore<State> &operator=(const HashCacheStore<State> &store) {
404     if (this != &store) {
405       CopyStates(store);
406       Reset();
407     }
408     return *this;
409   }
410
411   // Returns nullptr if state is not stored.
412   const State *GetState(StateId s) const {
413     const auto it = state_map_.find(s);
414     return it != state_map_.end() ? it->second : nullptr;
415   }
416
417   // Creates state if state is not stored.
418   State *GetMutableState(StateId s) {
419     auto *&state = state_map_[s];
420     if (!state) state = new (&state_alloc_) State(arc_alloc_);
421     return state;
422   }
423
424   // Similar to State::AddArc() but updates cache store book-keeping.
425   void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }
426
427   // Similar to State::SetArcs() but updates internal cache size; call only
428   // once.
429   void SetArcs(State *state) { state->SetArcs(); }
430
431   // Deletes all arcs.
432   void DeleteArcs(State *state) { state->DeleteArcs(); }
433
434   // Deletes some arcs.
435   void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }
436
437   // Deletes all cached states.
438   void Clear() {
439     for (auto it = state_map_.begin(); it != state_map_.end(); ++it) {
440       State::Destroy(it->second, &state_alloc_);
441     }
442     state_map_.clear();
443   }
444
445   // Iterates over cached states (in an arbitrary order).
446   bool Done() const { return iter_ == state_map_.end(); }
447
448   StateId Value() const { return iter_->first; }
449
450   void Next() { ++iter_; }
451
452   void Reset() { iter_ = state_map_.begin(); }
453
454   // Deletes current state and advances to next.
455   void Delete() {
456     State::Destroy(iter_->second, &state_alloc_);
457     state_map_.erase(iter_++);
458   }
459
460  private:
461   void CopyStates(const HashCacheStore<State> &store) {
462     Clear();
463     for (auto it = store.state_map_.begin(); it != store.state_map_.end();
464          ++it) {
465       state_map_[it->first] =
466           new (&state_alloc_) State(*it->second, arc_alloc_);
467     }
468   }
469
470   StateMap state_map_;                          // Map from state ID to state.
471   typename StateMap::iterator iter_;            // State map iterator.
472   typename State::StateAllocator state_alloc_;  // For state allocation.
473   typename State::ArcAllocator arc_alloc_;      // For arc allocation.
474 };
475
476 // Garbage-colllection cache stores.
477
478 // This class implements a simple garbage collection scheme when
479 // 'opts.gc_limit = 0'. In particular, the first cached state is reused for each
480 // new state so long as the reference count is zero on the to-be-reused state.
481 // Otherwise, the full underlying store is used. The caller can increment the
482 // reference count to inhibit the GC of in-use states (e.g., in an ArcIterator).
483 //
484 // The typical use case for this optimization is when a single pass over a
485 // cached
486 // FST is performed with only one-state expanded at a time.
487 template <class CacheStore>
488 class FirstCacheStore {
489  public:
490   using State = typename CacheStore::State;
491   using Arc = typename State::Arc;
492   using StateId = typename Arc::StateId;
493
494   // Required constructors/assignment operators.
495   explicit FirstCacheStore(const CacheOptions &opts)
496       : store_(opts),
497         cache_gc_(opts.gc_limit == 0),  // opts.gc ignored historically.
498         cache_first_state_id_(kNoStateId),
499         cache_first_state_(nullptr) {}
500
501   FirstCacheStore(const FirstCacheStore<CacheStore> &store)
502       : store_(store.store_),
503         cache_gc_(store.cache_gc_),
504         cache_first_state_id_(store.cache_first_state_id_),
505         cache_first_state_(store.cache_first_state_id_ != kNoStateId
506                                ? store_.GetMutableState(0)
507                                : nullptr) {}
508
509   FirstCacheStore<CacheStore> &operator=(
510       const FirstCacheStore<CacheStore> &store) {
511     if (this != &store) {
512       store_ = store.store_;
513       cache_gc_ = store.cache_gc_;
514       cache_first_state_id_ = store.cache_first_state_id_;
515       cache_first_state_ = store.cache_first_state_id_ != kNoStateId
516                                ? store_.GetMutableState(0)
517                                : nullptr;
518     }
519     return *this;
520   }
521
522   // Returns nullptr if state is not stored.
523   const State *GetState(StateId s) const {
524     // store_ state 0 may hold first cached state; the rest are shifted by 1.
525     return s == cache_first_state_id_ ? cache_first_state_
526                                       : store_.GetState(s + 1);
527   }
528
529   // Creates state if state is not stored.
530   State *GetMutableState(StateId s) {
531     // store_ state 0 used to hold first cached state; the rest are shifted by
532     // 1.
533     if (cache_first_state_id_ == s) {
534       return cache_first_state_;  // Request for first cached state.
535     }
536     if (cache_gc_) {
537       if (cache_first_state_id_ == kNoStateId) {
538         cache_first_state_id_ = s;  // Sets first cached state.
539         cache_first_state_ = store_.GetMutableState(0);
540         cache_first_state_->SetFlags(kCacheInit, kCacheInit);
541         cache_first_state_->ReserveArcs(2 * kAllocSize);
542         return cache_first_state_;
543       } else if (cache_first_state_->RefCount() == 0) {
544         cache_first_state_id_ = s;  // Updates first cached state.
545         cache_first_state_->Reset();
546         cache_first_state_->SetFlags(kCacheInit, kCacheInit);
547         return cache_first_state_;
548       } else {  // Keeps first cached state.
549         cache_first_state_->SetFlags(0, kCacheInit);  // Clears initialized bit.
550         cache_gc_ = false;                            // Disables GC.
551       }
552     }
553     auto *state = store_.GetMutableState(s + 1);
554     return state;
555   }
556
557   // Similar to State::AddArc() but updates cache store book-keeping.
558   void AddArc(State *state, const Arc &arc) { store_.AddArc(state, arc); }
559
560   // Similar to State::SetArcs() but updates internal cache size; call only
561   // once.
562   void SetArcs(State *state) { store_.SetArcs(state); }
563
564   // Deletes all arcs
565   void DeleteArcs(State *state) { store_.DeleteArcs(state); }
566
567   // Deletes some arcs
568   void DeleteArcs(State *state, size_t n) { store_.DeleteArcs(state, n); }
569
570   // Deletes all cached states
571   void Clear() {
572     store_.Clear();
573     cache_first_state_id_ = kNoStateId;
574     cache_first_state_ = nullptr;
575   }
576
577   // Iterates over cached states (in an arbitrary order). Only needed if GC is
578   // enabled.
579   bool Done() const { return store_.Done(); }
580
581   StateId Value() const {
582     // store_ state 0 may hold first cached state; rest shifted + 1.
583     const auto s = store_.Value();
584     return s ? s - 1 : cache_first_state_id_;
585   }
586
587   void Next() { store_.Next(); }
588
589   void Reset() { store_.Reset(); }
590
591   // Deletes current state and advances to next.
592   void Delete() {
593     if (Value() == cache_first_state_id_) {
594       cache_first_state_id_ = kNoStateId;
595       cache_first_state_ = nullptr;
596     }
597     store_.Delete();
598   }
599
600  private:
601   CacheStore store_;              // Underlying store.
602   bool cache_gc_;                 // GC enabled.
603   StateId cache_first_state_id_;  // First cached state ID.
604   State *cache_first_state_;      // First cached state.
605 };
606
607 // This class implements mark-sweep garbage collection on an underlying cache
608 // store. If GC is enabled, garbage collection of states is performed in a
609 // rough approximation of LRU order once when 'gc_limit' bytes is reached. The
610 // caller can increment the reference count to inhibit the GC of in-use state
611 // (e.g., in an ArcIterator). With GC enabled, the 'gc_limit' parameter allows
612 // the caller to trade-off time vs. space.
613 template <class CacheStore>
614 class GCCacheStore {
615  public:
616   using State = typename CacheStore::State;
617   using Arc = typename State::Arc;
618   using StateId = typename Arc::StateId;
619
620   // Required constructors/assignment operators.
621   explicit GCCacheStore(const CacheOptions &opts)
622       : store_(opts),
623         cache_gc_request_(opts.gc),
624         cache_limit_(opts.gc_limit > kMinCacheLimit ? opts.gc_limit
625                                                     : kMinCacheLimit),
626         cache_gc_(false),
627         cache_size_(0) {}
628
629   // Returns 0 if state is not stored.
630   const State *GetState(StateId s) const { return store_.GetState(s); }
631
632   // Creates state if state is not stored
633   State *GetMutableState(StateId s) {
634     auto *state = store_.GetMutableState(s);
635     if (cache_gc_request_ && !(state->Flags() & kCacheInit)) {
636       state->SetFlags(kCacheInit, kCacheInit);
637       cache_size_ += sizeof(State) + state->NumArcs() * sizeof(Arc);
638       // GC is enabled once an uninited state (from underlying store) is seen.
639       cache_gc_ = true;
640       if (cache_size_ > cache_limit_) GC(state, false);
641     }
642     return state;
643   }
644
645   // Similar to State::AddArc() but updates cache store book-keeping.
646   void AddArc(State *state, const Arc &arc) {
647     store_.AddArc(state, arc);
648     if (cache_gc_ && (state->Flags() & kCacheInit)) {
649       cache_size_ += sizeof(Arc);
650       if (cache_size_ > cache_limit_) GC(state, false);
651     }
652   }
653
654   // Similar to State::SetArcs() but updates internal cache size; call only
655   // once.
656   void SetArcs(State *state) {
657     store_.SetArcs(state);
658     if (cache_gc_ && (state->Flags() & kCacheInit)) {
659       cache_size_ += state->NumArcs() * sizeof(Arc);
660       if (cache_size_ > cache_limit_) GC(state, false);
661     }
662   }
663
664   // Deletes all arcs.
665   void DeleteArcs(State *state) {
666     if (cache_gc_ && (state->Flags() & kCacheInit)) {
667       cache_size_ -= state->NumArcs() * sizeof(Arc);
668     }
669     store_.DeleteArcs(state);
670   }
671
672   // Deletes some arcs.
673   void DeleteArcs(State *state, size_t n) {
674     if (cache_gc_ && (state->Flags() & kCacheInit)) {
675       cache_size_ -= n * sizeof(Arc);
676     }
677     store_.DeleteArcs(state, n);
678   }
679
680   // Deletes all cached states.
681   void Clear() {
682     store_.Clear();
683     cache_size_ = 0;
684   }
685
686   // Iterates over cached states (in an arbitrary order); only needed if GC is
687   // enabled.
688   bool Done() const { return store_.Done(); }
689
690   StateId Value() const { return store_.Value(); }
691
692   void Next() { store_.Next(); }
693
694   void Reset() { store_.Reset(); }
695
696   // Deletes current state and advances to next.
697   void Delete() {
698     if (cache_gc_) {
699       const auto *state = store_.GetState(Value());
700       if (state->Flags() & kCacheInit) {
701         cache_size_ -= sizeof(State) + state->NumArcs() * sizeof(Arc);
702       }
703     }
704     store_.Delete();
705   }
706
707   // Removes from the cache store (not referenced-counted and not the current)
708   // states that have not been accessed since the last GC until at most
709   // cache_fraction * cache_limit_ bytes are cached. If that fails to free
710   // enough, attempts to uncaching recently visited states as well. If still
711   // unable to free enough memory, then widens cache_limit_.
712   void GC(const State *current, bool free_recent, float cache_fraction = 0.666);
713
714   // Returns the current cache size in bytes or 0 if GC is disabled.
715   size_t CacheSize() const { return cache_size_; }
716
717   // Returns the cache limit in bytes.
718   size_t CacheLimit() const { return cache_limit_; }
719
720  private:
721   static constexpr size_t kMinCacheLimit = 8096;  // Minimum cache limit.
722
723   CacheStore store_;       // Underlying store.
724   bool cache_gc_request_;  // GC requested but possibly not yet enabled.
725   size_t cache_limit_;     // Number of bytes allowed before GC.
726   bool cache_gc_;          // GC enabled
727   size_t cache_size_;      // Number of bytes cached.
728 };
729
730 template <class CacheStore>
731 void GCCacheStore<CacheStore>::GC(const State *current, bool free_recent,
732                                   float cache_fraction) {
733   if (!cache_gc_) return;
734   VLOG(2) << "GCCacheStore: Enter GC: object = "
735           << "(" << this << "), free recently cached = " << free_recent
736           << ", cache size = " << cache_size_
737           << ", cache frac = " << cache_fraction
738           << ", cache limit = " << cache_limit_ << "\n";
739   size_t cache_target = cache_fraction * cache_limit_;
740   store_.Reset();
741   while (!store_.Done()) {
742     auto *state = store_.GetMutableState(store_.Value());
743     if (cache_size_ > cache_target && state->RefCount() == 0 &&
744         (free_recent || !(state->Flags() & kCacheRecent)) && state != current) {
745       if (state->Flags() & kCacheInit) {
746         size_t size = sizeof(State) + state->NumArcs() * sizeof(Arc);
747         CHECK_LE(size, cache_size_);
748         cache_size_ -= size;
749       }
750       store_.Delete();
751     } else {
752       state->SetFlags(0, kCacheRecent);
753       store_.Next();
754     }
755   }
756   if (!free_recent && cache_size_ > cache_target) {  // Recurses on recent.
757     GC(current, true, cache_fraction);
758   } else if (cache_target > 0) {  // Widens cache limit.
759     while (cache_size_ > cache_target) {
760       cache_limit_ *= 2;
761       cache_target *= 2;
762     }
763   } else if (cache_size_ > 0) {
764     FSTERROR() << "GCCacheStore:GC: Unable to free all cached states";
765   }
766   VLOG(2) << "GCCacheStore: Exit GC: object = "
767           << "(" << this << "), free recently cached = " << free_recent
768           << ", cache size = " << cache_size_
769           << ", cache frac = " << cache_fraction
770           << ", cache limit = " << cache_limit_ << "\n";
771 }
772
773 template <class CacheStore>
774 constexpr size_t GCCacheStore<CacheStore>::kMinCacheLimit;
775
776 // This class is the default cache state and store used by CacheBaseImpl.
777 // It uses VectorCacheStore for storage decorated by FirstCacheStore
778 // and GCCacheStore to do (optional) garbage collection.
779 template <class Arc>
780 class DefaultCacheStore
781     : public GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>> {
782  public:
783   explicit DefaultCacheStore(const CacheOptions &opts)
784       : GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>>(opts) {
785   }
786 };
787
788 namespace internal {
789
790 // This class is used to cache FST elements stored in states of type State
791 // (see CacheState) with the flags used to indicate what has been cached. Use
792 // HasStart(), HasFinal(), and HasArcs() to determine if cached and SetStart(),
793 // SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note that you
794 // must set the final weight even if the state is non-final to mark it as
795 // cached. The state storage method and any garbage collection policy are
796 // determined by the cache store. If the store is passed in with the options,
797 // CacheBaseImpl takes ownership.
798 template <class State,
799           class CacheStore = DefaultCacheStore<typename State::Arc>>
800 class CacheBaseImpl : public FstImpl<typename State::Arc> {
801  public:
802   using Arc = typename State::Arc;
803   using StateId = typename Arc::StateId;
804   using Weight = typename Arc::Weight;
805
806   using Store = CacheStore;
807
808   using FstImpl<Arc>::Type;
809   using FstImpl<Arc>::Properties;
810
811   explicit CacheBaseImpl(const CacheOptions &opts = CacheOptions())
812       : has_start_(false),
813         cache_start_(kNoStateId),
814         nknown_states_(0),
815         min_unexpanded_state_id_(0),
816         max_expanded_state_id_(-1),
817         cache_gc_(opts.gc),
818         cache_limit_(opts.gc_limit),
819         cache_store_(new CacheStore(opts)),
820         new_cache_store_(true),
821         own_cache_store_(true) {}
822
823   explicit CacheBaseImpl(const CacheImplOptions<CacheStore> &opts)
824       : has_start_(false),
825         cache_start_(kNoStateId),
826         nknown_states_(0),
827         min_unexpanded_state_id_(0),
828         max_expanded_state_id_(-1),
829         cache_gc_(opts.gc),
830         cache_limit_(opts.gc_limit),
831         cache_store_(opts.store ? opts.store : new CacheStore(CacheOptions(
832                                                    opts.gc, opts.gc_limit))),
833         new_cache_store_(!opts.store),
834         own_cache_store_(opts.store ? opts.own_store : true) {}
835
836   // Preserve gc parameters. If preserve_cache is true, also preserves
837   // cache data.
838   CacheBaseImpl(const CacheBaseImpl<State, CacheStore> &impl,
839                 bool preserve_cache = false)
840       : FstImpl<Arc>(),
841         has_start_(false),
842         cache_start_(kNoStateId),
843         nknown_states_(0),
844         min_unexpanded_state_id_(0),
845         max_expanded_state_id_(-1),
846         cache_gc_(impl.cache_gc_),
847         cache_limit_(impl.cache_limit_),
848         cache_store_(new CacheStore(CacheOptions(cache_gc_, cache_limit_))),
849         new_cache_store_(impl.new_cache_store_ || !preserve_cache),
850         own_cache_store_(true) {
851     if (preserve_cache) {
852       *cache_store_ = *impl.cache_store_;
853       has_start_ = impl.has_start_;
854       cache_start_ = impl.cache_start_;
855       nknown_states_ = impl.nknown_states_;
856       expanded_states_ = impl.expanded_states_;
857       min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
858       max_expanded_state_id_ = impl.max_expanded_state_id_;
859     }
860   }
861
862   ~CacheBaseImpl() override { if (own_cache_store_) delete cache_store_; }
863
864   void SetStart(StateId s) {
865     cache_start_ = s;
866     has_start_ = true;
867     if (s >= nknown_states_) nknown_states_ = s + 1;
868   }
869
870   void SetFinal(StateId s, Weight weight) {
871     auto *state = cache_store_->GetMutableState(s);
872     state->SetFinal(std::move(weight));
873     static constexpr auto flags = kCacheFinal | kCacheRecent;
874     state->SetFlags(flags, flags);
875   }
876
877 // Disabled to ensure PushArc not AddArc is used in existing code
878 // TODO(sorenj): re-enable for backing store
879 #if 0
880   // AddArc adds a single arc to a state and does incremental cache
881   // book-keeping. For efficiency, prefer PushArc and SetArcs below
882   // when possible.
883   void AddArc(StateId s, const Arc &arc) {
884     auto *state = cache_store_->GetMutableState(s);
885     cache_store_->AddArc(state, arc);
886     if (arc.nextstate >= nknown_states_)
887       nknown_states_ = arc.nextstate + 1;
888     SetExpandedState(s);
889     static constexpr auto flags = kCacheArcs | kCacheRecent;
890     state->SetFlags(flags, flags);
891   }
892 #endif
893
894   // Adds a single arc to a state but delays cache book-keeping. SetArcs must
895   // be called when all PushArc calls at a state are complete. Do not mix with
896   // calls to AddArc.
897   void PushArc(StateId s, const Arc &arc) {
898     auto *state = cache_store_->GetMutableState(s);
899     state->PushArc(arc);
900   }
901
902   // Marks arcs of a state as cached and does cache book-keeping after all
903   // calls to PushArc have been completed. Do not mix with calls to AddArc.
904   void SetArcs(StateId s) {
905     auto *state = cache_store_->GetMutableState(s);
906     cache_store_->SetArcs(state);
907     const auto narcs = state->NumArcs();
908     for (size_t a = 0; a < narcs; ++a) {
909       const auto &arc = state->GetArc(a);
910       if (arc.nextstate >= nknown_states_) nknown_states_ = arc.nextstate + 1;
911     }
912     SetExpandedState(s);
913     static constexpr auto flags = kCacheArcs | kCacheRecent;
914     state->SetFlags(flags, flags);
915   }
916
917   void ReserveArcs(StateId s, size_t n) {
918     auto *state = cache_store_->GetMutableState(s);
919     state->ReserveArcs(n);
920   }
921
922   void DeleteArcs(StateId s) {
923     auto *state = cache_store_->GetMutableState(s);
924     cache_store_->DeleteArcs(state);
925   }
926
927   void DeleteArcs(StateId s, size_t n) {
928     auto *state = cache_store_->GetMutableState(s);
929     cache_store_->DeleteArcs(state, n);
930   }
931
932   void Clear() {
933     nknown_states_ = 0;
934     min_unexpanded_state_id_ = 0;
935     max_expanded_state_id_ = -1;
936     has_start_ = false;
937     cache_start_ = kNoStateId;
938     cache_store_->Clear();
939   }
940
941   // Is the start state cached?
942   bool HasStart() const {
943     if (!has_start_ && Properties(kError)) has_start_ = true;
944     return has_start_;
945   }
946
947   // Is the final weight of the state cached?
948   bool HasFinal(StateId s) const {
949     const auto *state = cache_store_->GetState(s);
950     if (state && state->Flags() & kCacheFinal) {
951       state->SetFlags(kCacheRecent, kCacheRecent);
952       return true;
953     } else {
954       return false;
955     }
956   }
957
958   // Are arcs of the state cached?
959   bool HasArcs(StateId s) const {
960     const auto *state = cache_store_->GetState(s);
961     if (state && state->Flags() & kCacheArcs) {
962       state->SetFlags(kCacheRecent, kCacheRecent);
963       return true;
964     } else {
965       return false;
966     }
967   }
968
969   StateId Start() const { return cache_start_; }
970
971   Weight Final(StateId s) const {
972     const auto *state = cache_store_->GetState(s);
973     return state->Final();
974   }
975
976   size_t NumArcs(StateId s) const {
977     const auto *state = cache_store_->GetState(s);
978     return state->NumArcs();
979   }
980
981   size_t NumInputEpsilons(StateId s) const {
982     const auto *state = cache_store_->GetState(s);
983     return state->NumInputEpsilons();
984   }
985
986   size_t NumOutputEpsilons(StateId s) const {
987     const auto *state = cache_store_->GetState(s);
988     return state->NumOutputEpsilons();
989   }
990
991   // Provides information needed for generic arc iterator.
992   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
993     const auto *state = cache_store_->GetState(s);
994     data->base = nullptr;
995     data->narcs = state->NumArcs();
996     data->arcs = state->Arcs();
997     data->ref_count = state->MutableRefCount();
998     state->IncrRefCount();
999   }
1000
1001   // Number of known states.
1002   StateId NumKnownStates() const { return nknown_states_; }
1003
1004   // Updates number of known states, taking into account the passed state ID.
1005   void UpdateNumKnownStates(StateId s) {
1006     if (s >= nknown_states_) nknown_states_ = s + 1;
1007   }
1008
1009   // Finds the mininum never-expanded state ID.
1010   StateId MinUnexpandedState() const {
1011     while (min_unexpanded_state_id_ <= max_expanded_state_id_ &&
1012            ExpandedState(min_unexpanded_state_id_)) {
1013       ++min_unexpanded_state_id_;
1014     }
1015     return min_unexpanded_state_id_;
1016   }
1017
1018   // Returns maximum ever-expanded state ID.
1019   StateId MaxExpandedState() const { return max_expanded_state_id_; }
1020
1021   void SetExpandedState(StateId s) {
1022     if (s > max_expanded_state_id_) max_expanded_state_id_ = s;
1023     if (s < min_unexpanded_state_id_) return;
1024     if (s == min_unexpanded_state_id_) ++min_unexpanded_state_id_;
1025     if (cache_gc_ || cache_limit_ == 0) {
1026       while (expanded_states_.size() <= s) expanded_states_.push_back(false);
1027       expanded_states_[s] = true;
1028     }
1029   }
1030
1031   bool ExpandedState(StateId s) const {
1032     if (cache_gc_ || cache_limit_ == 0) {
1033       return expanded_states_[s];
1034     } else if (new_cache_store_) {
1035       return cache_store_->GetState(s) != nullptr;
1036     } else {
1037       // If the cache was not created by this class, then the cached state needs
1038       // to be inspected to update nknown_states_.
1039       return false;
1040     }
1041   }
1042
1043   const CacheStore *GetCacheStore() const { return cache_store_; }
1044
1045   CacheStore *GetCacheStore() { return cache_store_; }
1046
1047   // Caching on/off switch, limit and size accessors.
1048
1049   bool GetCacheGc() const { return cache_gc_; }
1050
1051   size_t GetCacheLimit() const { return cache_limit_; }
1052
1053  private:
1054   mutable bool has_start_;                   // Is the start state cached?
1055   StateId cache_start_;                      // ID of start state.
1056   StateId nknown_states_;                    // Number of known states.
1057   std::vector<bool> expanded_states_;        // States that have been expanded.
1058   mutable StateId min_unexpanded_state_id_;  // Minimum never-expanded state ID
1059   mutable StateId max_expanded_state_id_;    // Maximum ever-expanded state ID
1060   bool cache_gc_;                            // GC enabled.
1061   size_t cache_limit_;       // Number of bytes allowed before GC.
1062   CacheStore *cache_store_;  // The store of cached states.
1063   bool new_cache_store_;     // Was the store was created by class?
1064   bool own_cache_store_;     // Is the store owned by class?
1065
1066   CacheBaseImpl &operator=(const CacheBaseImpl &impl) = delete;
1067 };
1068
1069 // A CacheBaseImpl with the default cache state type.
1070 template <class Arc>
1071 class CacheImpl : public CacheBaseImpl<CacheState<Arc>> {
1072  public:
1073   using State = CacheState<Arc>;
1074
1075   CacheImpl() {}
1076
1077   explicit CacheImpl(const CacheOptions &opts)
1078       : CacheBaseImpl<CacheState<Arc>>(opts) {}
1079
1080   CacheImpl(const CacheImpl<Arc> &impl, bool preserve_cache = false)
1081       : CacheBaseImpl<State>(impl, preserve_cache) {}
1082
1083  private:
1084   CacheImpl &operator=(const CacheImpl &impl) = delete;
1085 };
1086
1087 }  // namespace internal
1088
1089 // Use this to make a state iterator for a CacheBaseImpl-derived FST, which must
1090 // have Arc and Store types defined. Note this iterator only returns those
1091 // states reachable from the initial state, so consider implementing a
1092 // class-specific one.
1093 //
1094 // This class may be derived from.
1095 template <class FST>
1096 class CacheStateIterator : public StateIteratorBase<typename FST::Arc> {
1097  public:
1098   using Arc = typename FST::Arc;
1099   using StateId = typename Arc::StateId;
1100   using Weight = typename Arc::Weight;
1101
1102   using Store = typename FST::Store;
1103   using State = typename Store::State;
1104   using Impl = internal::CacheBaseImpl<State, Store>;
1105
1106   CacheStateIterator(const FST &fst, Impl *impl)
1107       : fst_(fst), impl_(impl), s_(0) {
1108     fst_.Start();  // Forces start state.
1109   }
1110
1111   bool Done() const final {
1112     if (s_ < impl_->NumKnownStates()) return false;
1113     for (StateId u = impl_->MinUnexpandedState(); u < impl_->NumKnownStates();
1114          u = impl_->MinUnexpandedState()) {
1115       // Forces state expansion.
1116       ArcIterator<FST> aiter(fst_, u);
1117       aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
1118       for (; !aiter.Done(); aiter.Next()) {
1119         impl_->UpdateNumKnownStates(aiter.Value().nextstate);
1120       }
1121       impl_->SetExpandedState(u);
1122       if (s_ < impl_->NumKnownStates()) return false;
1123     }
1124     return true;
1125   }
1126
1127   StateId Value() const final { return s_; }
1128
1129   void Next() final { ++s_; }
1130
1131   void Reset() final { s_ = 0; }
1132
1133  private:
1134   const FST &fst_;
1135   Impl *impl_;
1136   StateId s_;
1137 };
1138
1139 // Used to make an arc iterator for a CacheBaseImpl-derived FST, which must
1140 // have Arc and State types defined.
1141 template <class FST>
1142 class CacheArcIterator {
1143  public:
1144   using Arc = typename FST::Arc;
1145   using StateId = typename Arc::StateId;
1146   using Weight = typename Arc::Weight;
1147
1148   using Store = typename FST::Store;
1149   using State = typename Store::State;
1150   using Impl = internal::CacheBaseImpl<State, Store>;
1151
1152   CacheArcIterator(Impl *impl, StateId s) : i_(0) {
1153     state_ = impl->GetCacheStore()->GetMutableState(s);
1154     state_->IncrRefCount();
1155   }
1156
1157   ~CacheArcIterator() { state_->DecrRefCount(); }
1158
1159   bool Done() const { return i_ >= state_->NumArcs(); }
1160
1161   const Arc &Value() const { return state_->GetArc(i_); }
1162
1163   void Next() { ++i_; }
1164
1165   size_t Position() const { return i_; }
1166
1167   void Reset() { i_ = 0; }
1168
1169   void Seek(size_t a) { i_ = a; }
1170
1171   constexpr uint32 Flags() const { return kArcValueFlags; }
1172
1173   void SetFlags(uint32 flags, uint32 mask) {}
1174
1175  private:
1176   const State *state_;
1177   size_t i_;
1178
1179   CacheArcIterator(const CacheArcIterator &) = delete;
1180   CacheArcIterator &operator=(const CacheArcIterator &) = delete;
1181 };
1182
1183 // Use this to make a mutable arc iterator for a CacheBaseImpl-derived FST,
1184 // which must have types Arc and Store defined.
1185 template <class FST>
1186 class CacheMutableArcIterator
1187     : public MutableArcIteratorBase<typename FST::Arc> {
1188  public:
1189   using Arc = typename FST::Arc;
1190   using StateId = typename Arc::StateId;
1191   using Weight = typename Arc::Weight;
1192
1193   using Store = typename FST::Store;
1194   using State = typename Store::State;
1195   using Impl = internal::CacheBaseImpl<State, Store>;
1196
1197   // User must call MutateCheck() in the constructor.
1198   CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
1199     state_ = impl_->GetCacheStore()->GetMutableState(s_);
1200     state_->IncrRefCount();
1201   }
1202
1203   ~CacheMutableArcIterator() override { state_->DecrRefCount(); }
1204
1205   bool Done() const final { return i_ >= state_->NumArcs(); }
1206
1207   const Arc &Value() const final { return state_->GetArc(i_); }
1208
1209   void Next() final { ++i_; }
1210
1211   size_t Position() const final { return i_; }
1212
1213   void Reset() final { i_ = 0; }
1214
1215   void Seek(size_t a) final { i_ = a; }
1216
1217   void SetValue(const Arc &arc) final { state_->SetArc(arc, i_); }
1218
1219   uint32 Flags() const final { return kArcValueFlags; }
1220
1221   void SetFlags(uint32, uint32) final {}
1222
1223  private:
1224   size_t i_;
1225   StateId s_;
1226   Impl *impl_;
1227   State *state_;
1228
1229   CacheMutableArcIterator(const CacheMutableArcIterator &) = delete;
1230   CacheMutableArcIterator &operator=(const CacheMutableArcIterator &) = delete;
1231 };
1232
1233 // Wrap existing CacheStore implementation to use with ExpanderFst.
1234 template <class CacheStore>
1235 class ExpanderCacheStore {
1236  public:
1237   using State = typename CacheStore::State;
1238   using Arc = typename CacheStore::Arc;
1239   using StateId = typename Arc::StateId;
1240   using Weight = typename Arc::Weight;
1241
1242   explicit ExpanderCacheStore(const CacheOptions &opts = CacheOptions())
1243       : store_(opts) {}
1244
1245   template <class Expander>
1246   State *FindOrExpand(Expander &expander, StateId s) {  // NOLINT
1247     auto *state = store_.GetMutableState(s);
1248     if (state->Flags()) {
1249       state->SetFlags(kCacheRecent, kCacheRecent);
1250     } else {
1251       StateBuilder builder(state);
1252       expander.Expand(s, &builder);
1253       state->SetFlags(kCacheFlags, kCacheFlags);
1254       store_.SetArcs(state);
1255     }
1256     return state;
1257   }
1258
1259  private:
1260   CacheStore store_;
1261
1262   struct StateBuilder {
1263     State *state;
1264
1265     explicit StateBuilder(State *state_) : state(state_) {}
1266
1267     void AddArc(const Arc &arc) { state->PushArc(arc); }
1268
1269     void SetFinal(Weight weight) { state->SetFinal(std::move(weight)); }
1270   };
1271 };
1272
1273 }  // namespace fst
1274
1275 #endif  // FST_LIB_CACHE_H_