1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // An FST implementation that caches FST elements of a delayed computation.
6 #ifndef FST_LIB_CACHE_H_
7 #define FST_LIB_CACHE_H_
10 #include <unordered_map>
11 using std::unordered_map;
12 using std::unordered_multimap;
18 #include <fst/vector-fst.h>
21 DECLARE_bool(fst_default_cache_gc);
22 DECLARE_int64(fst_default_cache_gc_limit);
26 // Options for controlling caching behavior; higher level than CacheImplOptions.
28 bool gc; // Enables GC.
29 size_t gc_limit; // Number of bytes allowed before GC.
31 explicit CacheOptions(bool gc = FLAGS_fst_default_cache_gc,
32 size_t gc_limit = FLAGS_fst_default_cache_gc_limit)
33 : gc(gc), gc_limit(gc_limit) {}
36 // Options for controlling caching behavior, at a lower level than
37 // CacheOptions; templated on the cache store and allows passing the store.
38 template <class CacheStore>
39 struct CacheImplOptions {
40 bool gc; // Enables GC.
41 size_t gc_limit; // Number of bytes allowed before GC.
42 CacheStore *store; // Cache store.
43 bool own_store; // Should CacheImpl takes ownership of the store?
45 explicit CacheImplOptions(bool gc = FLAGS_fst_default_cache_gc,
46 size_t gc_limit = FLAGS_fst_default_cache_gc_limit,
47 CacheStore *store = nullptr)
48 : gc(gc), gc_limit(gc_limit), store(store), own_store(true) {}
50 explicit CacheImplOptions(const CacheOptions &opts)
51 : gc(opts.gc), gc_limit(opts.gc_limit), store(nullptr), own_store(true) {}
55 constexpr uint32 kCacheFinal = 0x0001; // Final weight has been cached.
56 constexpr uint32 kCacheArcs = 0x0002; // Arcs have been cached.
57 constexpr uint32 kCacheInit = 0x0004; // Initialized by GC.
58 constexpr uint32 kCacheRecent = 0x0008; // Visited since GC.
59 constexpr uint32 kCacheFlags =
60 kCacheFinal | kCacheArcs | kCacheInit | kCacheRecent;
62 // Cache state, with arcs stored in a per-state std::vector.
63 template <class A, class M = PoolAllocator<A>>
67 using Label = typename Arc::Label;
68 using StateId = typename Arc::StateId;
69 using Weight = typename Arc::Weight;
71 using ArcAllocator = M;
72 using StateAllocator =
73 typename ArcAllocator::template rebind<CacheState<A, M>>::other;
75 // Provides STL allocator for arcs.
76 explicit CacheState(const ArcAllocator &alloc)
77 : final_(Weight::Zero()),
84 CacheState(const CacheState<A> &state, const ArcAllocator &alloc)
85 : final_(state.Final()),
86 niepsilons_(state.NumInputEpsilons()),
87 noepsilons_(state.NumOutputEpsilons()),
88 arcs_(state.arcs_.begin(), state.arcs_.end(), alloc),
89 flags_(state.Flags()),
93 final_ = Weight::Zero();
101 Weight Final() const { return final_; }
103 size_t NumInputEpsilons() const { return niepsilons_; }
105 size_t NumOutputEpsilons() const { return noepsilons_; }
107 size_t NumArcs() const { return arcs_.size(); }
109 const Arc &GetArc(size_t n) const { return arcs_[n]; }
111 // Used by the ArcIterator<Fst<Arc>> efficient implementation.
112 const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; }
114 // Accesses flags; used by the caller.
115 uint32 Flags() const { return flags_; }
117 // Accesses ref count; used by the caller.
118 int RefCount() const { return ref_count_; }
120 void SetFinal(Weight weight) { final_ = std::move(weight); }
122 void ReserveArcs(size_t n) { arcs_.reserve(n); }
124 // Adds one arc at a time with all needed book-keeping; use PushArc and
125 // SetArcs for a more efficient alternative.
126 void AddArc(const Arc &arc) {
127 arcs_.push_back(arc);
128 if (arc.ilabel == 0) ++niepsilons_;
129 if (arc.olabel == 0) ++noepsilons_;
132 // Adds one arc at a time with delayed book-keeping; finalize with SetArcs().
133 void PushArc(const Arc &arc) { arcs_.push_back(arc); }
135 // Finalizes arcs book-keeping; call only once.
137 for (const auto &arc : arcs_) {
138 if (arc.ilabel == 0) ++niepsilons_;
139 if (arc.olabel == 0) ++noepsilons_;
144 void SetArc(const Arc &arc, size_t n) {
145 if (arcs_[n].ilabel == 0) --niepsilons_;
146 if (arcs_[n].olabel == 0) --noepsilons_;
147 if (arc.ilabel == 0) ++niepsilons_;
148 if (arc.olabel == 0) ++noepsilons_;
159 void DeleteArcs(size_t n) {
160 for (size_t i = 0; i < n; ++i) {
161 if (arcs_.back().ilabel == 0) --niepsilons_;
162 if (arcs_.back().olabel == 0) --noepsilons_;
167 // Sets status flags; used by the caller.
168 void SetFlags(uint32 flags, uint32 mask) const {
173 // Mutates reference counts; used by the caller.
175 int IncrRefCount() const { return ++ref_count_; }
177 int DecrRefCount() const { return --ref_count_; }
179 // Used by the ArcIterator<Fst<Arc>> efficient implementation.
180 int *MutableRefCount() const { return &ref_count_; }
182 // Used for state class allocation.
183 void *operator new(size_t size, StateAllocator *alloc) {
184 return alloc->allocate(1);
187 // For state destruction and memory freeing.
188 static void Destroy(CacheState<Arc> *state, StateAllocator *alloc) {
190 state->~CacheState<Arc>();
191 alloc->deallocate(state, 1);
196 Weight final_; // Final weight.
197 size_t niepsilons_; // # of input epsilons.
198 size_t noepsilons_; // # of output epsilons.
199 std::vector<Arc, ArcAllocator> arcs_; // Arcs representation.
200 mutable uint32 flags_;
201 mutable int ref_count_; // If 0, available for GC.
204 // Cache store, allocating and storing states, providing a mapping from state
205 // IDs to cached states, and an iterator over these states. The state template
206 // argument must implement the CacheState interface. The state for a StateId s
207 // is constructed when requested by GetMutableState(s) if it is not yet stored.
208 // Initially, a state has a reference count of zero, but the user may increment
209 // or decrement this to control the time of destruction. In particular, a state
210 // is destroyed when:
212 // 1. This instance is destroyed, or
213 // 2. Clear() or Delete() is called, or
214 // 3. Possibly (implementation-dependently) when:
215 // - Garbage collection is enabled (as defined by opts.gc),
216 // - The cache store size exceeds the limits (as defined by opts.gc_limit),
217 // - The state's reference count is zero, and
218 // - The state is not the most recently requested state.
220 // template <class S>
221 // class CacheStore {
224 // using Arc = typename State::Arc;
225 // using StateId = typename Arc::StateId;
227 // // Required constructors/assignment operators.
228 // explicit CacheStore(const CacheOptions &opts);
230 // // Returns nullptr if state is not stored.
231 // const State *GetState(StateId s);
233 // // Creates state if state is not stored.
234 // State *GetMutableState(StateId s);
236 // // Similar to State::AddArc() but updates cache store book-keeping.
237 // void AddArc(State *state, const Arc &arc);
239 // // Similar to State::SetArcs() but updates cache store book-keeping; call
241 // void SetArcs(State *state);
243 // // Similar to State::DeleteArcs() but updates cache store book-keeping.
245 // void DeleteArcs(State *state);
247 // void DeleteArcs(State *state, size_t n);
249 // // Deletes all cached states.
252 // // Iterates over cached states (in an arbitrary order); only needed if
253 // // opts.gc is true.
254 // bool Done() const; // End of iteration.
255 // StateId Value() const; // Current state.
256 // void Next(); // Advances to next state (when !Done).
257 // void Reset(); // Returns to initial condition.
258 // void Delete(); // Deletes current state and advances to next.
261 // Container cache stores.
263 // This class uses a vector of pointers to states to store cached states.
265 class VectorCacheStore {
268 using Arc = typename State::Arc;
269 using StateId = typename Arc::StateId;
270 using StateList = std::list<StateId, PoolAllocator<StateId>>;
272 // Required constructors/assignment operators.
273 explicit VectorCacheStore(const CacheOptions &opts) : cache_gc_(opts.gc) {
278 VectorCacheStore(const VectorCacheStore<S> &store)
279 : cache_gc_(store.cache_gc_) {
284 ~VectorCacheStore() { Clear(); }
286 VectorCacheStore<State> &operator=(const VectorCacheStore<State> &store) {
287 if (this != &store) {
294 // Returns nullptr if state is not stored.
295 const State *GetState(StateId s) const {
296 return s < state_vec_.size() ? state_vec_[s] : nullptr;
299 // Creates state if state is not stored.
300 State *GetMutableState(StateId s) {
301 State *state = nullptr;
302 if (s >= state_vec_.size()) {
303 state_vec_.resize(s + 1, nullptr);
305 state = state_vec_[s];
308 state = new (&state_alloc_) State(arc_alloc_);
309 state_vec_[s] = state;
310 if (cache_gc_) state_list_.push_back(s);
315 // Similar to State::AddArc() but updates cache store book-keeping
316 void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }
318 // Similar to State::SetArcs() but updates cache store book-keeping; call
320 void SetArcs(State *state) { state->SetArcs(); }
323 void DeleteArcs(State *state) { state->DeleteArcs(); }
325 // Deletes some arcs.
326 void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }
328 // Deletes all cached states.
330 for (StateId s = 0; s < state_vec_.size(); ++s) {
331 State::Destroy(state_vec_[s], &state_alloc_);
337 // Iterates over cached states (in an arbitrary order); only works if GC is
338 // enabled (o.w. avoiding state_list_ overhead).
339 bool Done() const { return iter_ == state_list_.end(); }
341 StateId Value() const { return *iter_; }
343 void Next() { ++iter_; }
345 void Reset() { iter_ = state_list_.begin(); }
347 // Deletes current state and advances to next.
349 State::Destroy(state_vec_[*iter_], &state_alloc_);
350 state_vec_[*iter_] = nullptr;
351 state_list_.erase(iter_++);
355 void CopyStates(const VectorCacheStore<State> &store) {
357 state_vec_.reserve(store.state_vec_.size());
358 for (StateId s = 0; s < store.state_vec_.size(); ++s) {
359 State *state = nullptr;
360 const auto *store_state = store.state_vec_[s];
362 state = new (&state_alloc_) State(*store_state, arc_alloc_);
363 if (cache_gc_) state_list_.push_back(s);
365 state_vec_.push_back(state);
369 bool cache_gc_; // Supports iteration when true.
370 std::vector<State *> state_vec_; // Vector of states (or null).
371 StateList state_list_; // List of states.
372 typename StateList::iterator iter_; // State list iterator.
373 typename State::StateAllocator state_alloc_; // For state allocation.
374 typename State::ArcAllocator arc_alloc_; // For arc allocation.
377 // This class uses a hash map from state IDs to pointers to cached states.
379 class HashCacheStore {
382 using Arc = typename State::Arc;
383 using StateId = typename Arc::StateId;
386 std::unordered_map<StateId, State *, std::hash<StateId>,
387 std::equal_to<StateId>,
388 PoolAllocator<std::pair<const StateId, State *>>>;
390 // Required constructors/assignment operators.
391 explicit HashCacheStore(const CacheOptions &opts) {
396 HashCacheStore(const HashCacheStore<S> &store) {
401 ~HashCacheStore() { Clear(); }
403 HashCacheStore<State> &operator=(const HashCacheStore<State> &store) {
404 if (this != &store) {
411 // Returns nullptr if state is not stored.
412 const State *GetState(StateId s) const {
413 const auto it = state_map_.find(s);
414 return it != state_map_.end() ? it->second : nullptr;
417 // Creates state if state is not stored.
418 State *GetMutableState(StateId s) {
419 auto *&state = state_map_[s];
420 if (!state) state = new (&state_alloc_) State(arc_alloc_);
424 // Similar to State::AddArc() but updates cache store book-keeping.
425 void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }
427 // Similar to State::SetArcs() but updates internal cache size; call only
429 void SetArcs(State *state) { state->SetArcs(); }
432 void DeleteArcs(State *state) { state->DeleteArcs(); }
434 // Deletes some arcs.
435 void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }
437 // Deletes all cached states.
439 for (auto it = state_map_.begin(); it != state_map_.end(); ++it) {
440 State::Destroy(it->second, &state_alloc_);
445 // Iterates over cached states (in an arbitrary order).
446 bool Done() const { return iter_ == state_map_.end(); }
448 StateId Value() const { return iter_->first; }
450 void Next() { ++iter_; }
452 void Reset() { iter_ = state_map_.begin(); }
454 // Deletes current state and advances to next.
456 State::Destroy(iter_->second, &state_alloc_);
457 state_map_.erase(iter_++);
461 void CopyStates(const HashCacheStore<State> &store) {
463 for (auto it = store.state_map_.begin(); it != store.state_map_.end();
465 state_map_[it->first] =
466 new (&state_alloc_) State(*it->second, arc_alloc_);
470 StateMap state_map_; // Map from state ID to state.
471 typename StateMap::iterator iter_; // State map iterator.
472 typename State::StateAllocator state_alloc_; // For state allocation.
473 typename State::ArcAllocator arc_alloc_; // For arc allocation.
476 // Garbage-collection cache stores.
478 // This class implements a simple garbage collection scheme when
479 // 'opts.gc_limit = 0'. In particular, the first cached state is reused for each
480 // new state so long as the reference count is zero on the to-be-reused state.
481 // Otherwise, the full underlying store is used. The caller can increment the
482 // reference count to inhibit the GC of in-use states (e.g., in an ArcIterator).
484 // The typical use case for this optimization is when a single pass over a
486 // FST is performed with only one state expanded at a time.
487 template <class CacheStore>
488 class FirstCacheStore {
490 using State = typename CacheStore::State;
491 using Arc = typename State::Arc;
492 using StateId = typename Arc::StateId;
494 // Required constructors/assignment operators.
495 explicit FirstCacheStore(const CacheOptions &opts)
497 cache_gc_(opts.gc_limit == 0), // opts.gc ignored historically.
498 cache_first_state_id_(kNoStateId),
499 cache_first_state_(nullptr) {}
501 FirstCacheStore(const FirstCacheStore<CacheStore> &store)
502 : store_(store.store_),
503 cache_gc_(store.cache_gc_),
504 cache_first_state_id_(store.cache_first_state_id_),
505 cache_first_state_(store.cache_first_state_id_ != kNoStateId
506 ? store_.GetMutableState(0)
509 FirstCacheStore<CacheStore> &operator=(
510 const FirstCacheStore<CacheStore> &store) {
511 if (this != &store) {
512 store_ = store.store_;
513 cache_gc_ = store.cache_gc_;
514 cache_first_state_id_ = store.cache_first_state_id_;
515 cache_first_state_ = store.cache_first_state_id_ != kNoStateId
516 ? store_.GetMutableState(0)
522 // Returns nullptr if state is not stored.
523 const State *GetState(StateId s) const {
524 // store_ state 0 may hold first cached state; the rest are shifted by 1.
525 return s == cache_first_state_id_ ? cache_first_state_
526 : store_.GetState(s + 1);
529 // Creates state if state is not stored.
530 State *GetMutableState(StateId s) {
531 // store_ state 0 used to hold first cached state; the rest are shifted by
533 if (cache_first_state_id_ == s) {
534 return cache_first_state_; // Request for first cached state.
537 if (cache_first_state_id_ == kNoStateId) {
538 cache_first_state_id_ = s; // Sets first cached state.
539 cache_first_state_ = store_.GetMutableState(0);
540 cache_first_state_->SetFlags(kCacheInit, kCacheInit);
541 cache_first_state_->ReserveArcs(2 * kAllocSize);
542 return cache_first_state_;
543 } else if (cache_first_state_->RefCount() == 0) {
544 cache_first_state_id_ = s; // Updates first cached state.
545 cache_first_state_->Reset();
546 cache_first_state_->SetFlags(kCacheInit, kCacheInit);
547 return cache_first_state_;
548 } else { // Keeps first cached state.
549 cache_first_state_->SetFlags(0, kCacheInit); // Clears initialized bit.
550 cache_gc_ = false; // Disables GC.
553 auto *state = store_.GetMutableState(s + 1);
557 // Similar to State::AddArc() but updates cache store book-keeping.
558 void AddArc(State *state, const Arc &arc) { store_.AddArc(state, arc); }
560 // Similar to State::SetArcs() but updates internal cache size; call only
562 void SetArcs(State *state) { store_.SetArcs(state); }
565 void DeleteArcs(State *state) { store_.DeleteArcs(state); }
568 void DeleteArcs(State *state, size_t n) { store_.DeleteArcs(state, n); }
570 // Deletes all cached states
573 cache_first_state_id_ = kNoStateId;
574 cache_first_state_ = nullptr;
577 // Iterates over cached states (in an arbitrary order). Only needed if GC is
579 bool Done() const { return store_.Done(); }
581 StateId Value() const {
582 // store_ state 0 may hold first cached state; rest shifted + 1.
583 const auto s = store_.Value();
584 return s ? s - 1 : cache_first_state_id_;
587 void Next() { store_.Next(); }
589 void Reset() { store_.Reset(); }
591 // Deletes current state and advances to next.
593 if (Value() == cache_first_state_id_) {
594 cache_first_state_id_ = kNoStateId;
595 cache_first_state_ = nullptr;
601 CacheStore store_; // Underlying store.
602 bool cache_gc_; // GC enabled.
603 StateId cache_first_state_id_; // First cached state ID.
604 State *cache_first_state_; // First cached state.
607 // This class implements mark-sweep garbage collection on an underlying cache
608 // store. If GC is enabled, garbage collection of states is performed in a
609 // rough approximation of LRU order once when 'gc_limit' bytes is reached. The
610 // caller can increment the reference count to inhibit the GC of in-use state
611 // (e.g., in an ArcIterator). With GC enabled, the 'gc_limit' parameter allows
612 // the caller to trade-off time vs. space.
613 template <class CacheStore>
616 using State = typename CacheStore::State;
617 using Arc = typename State::Arc;
618 using StateId = typename Arc::StateId;
620 // Required constructors/assignment operators.
621 explicit GCCacheStore(const CacheOptions &opts)
623 cache_gc_request_(opts.gc),
624 cache_limit_(opts.gc_limit > kMinCacheLimit ? opts.gc_limit
629 // Returns 0 if state is not stored.
630 const State *GetState(StateId s) const { return store_.GetState(s); }
632 // Creates state if state is not stored
633 State *GetMutableState(StateId s) {
634 auto *state = store_.GetMutableState(s);
635 if (cache_gc_request_ && !(state->Flags() & kCacheInit)) {
636 state->SetFlags(kCacheInit, kCacheInit);
637 cache_size_ += sizeof(State) + state->NumArcs() * sizeof(Arc);
638 // GC is enabled once an uninited state (from underlying store) is seen.
640 if (cache_size_ > cache_limit_) GC(state, false);
645 // Similar to State::AddArc() but updates cache store book-keeping.
646 void AddArc(State *state, const Arc &arc) {
647 store_.AddArc(state, arc);
648 if (cache_gc_ && (state->Flags() & kCacheInit)) {
649 cache_size_ += sizeof(Arc);
650 if (cache_size_ > cache_limit_) GC(state, false);
654 // Similar to State::SetArcs() but updates internal cache size; call only
656 void SetArcs(State *state) {
657 store_.SetArcs(state);
658 if (cache_gc_ && (state->Flags() & kCacheInit)) {
659 cache_size_ += state->NumArcs() * sizeof(Arc);
660 if (cache_size_ > cache_limit_) GC(state, false);
665 void DeleteArcs(State *state) {
666 if (cache_gc_ && (state->Flags() & kCacheInit)) {
667 cache_size_ -= state->NumArcs() * sizeof(Arc);
669 store_.DeleteArcs(state);
672 // Deletes some arcs.
673 void DeleteArcs(State *state, size_t n) {
674 if (cache_gc_ && (state->Flags() & kCacheInit)) {
675 cache_size_ -= n * sizeof(Arc);
677 store_.DeleteArcs(state, n);
680 // Deletes all cached states.
686 // Iterates over cached states (in an arbitrary order); only needed if GC is
688 bool Done() const { return store_.Done(); }
690 StateId Value() const { return store_.Value(); }
692 void Next() { store_.Next(); }
694 void Reset() { store_.Reset(); }
696 // Deletes current state and advances to next.
699 const auto *state = store_.GetState(Value());
700 if (state->Flags() & kCacheInit) {
701 cache_size_ -= sizeof(State) + state->NumArcs() * sizeof(Arc);
707 // Removes from the cache store those states that have not been accessed
708 // since the last GC (skipping reference-counted states and the current
709 // state) until at most cache_fraction * cache_limit_ bytes are cached. If
710 // that fails to free enough, attempts to uncache recently visited states as
711 // well. If still unable to free enough memory, then widens cache_limit_.
712 void GC(const State *current, bool free_recent, float cache_fraction = 0.666);
714 // Returns the current cache size in bytes or 0 if GC is disabled.
715 size_t CacheSize() const { return cache_size_; }
717 // Returns the cache limit in bytes.
718 size_t CacheLimit() const { return cache_limit_; }
// Smallest cache limit in bytes; the constructor compares opts.gc_limit
// against it. NOTE(review): 8096 is not a power of two and 8192 may have been
// intended -- confirm before changing, since that alters when GC runs.
static constexpr size_t kMinCacheLimit{8096};  // Minimum cache limit.
723 CacheStore store_; // Underlying store.
724 bool cache_gc_request_; // GC requested but possibly not yet enabled.
725 size_t cache_limit_; // Number of bytes allowed before GC.
726 bool cache_gc_; // GC enabled
727 size_t cache_size_; // Number of bytes cached.
730 template <class CacheStore>
731 void GCCacheStore<CacheStore>::GC(const State *current, bool free_recent,
732 float cache_fraction) {
733 if (!cache_gc_) return;
734 VLOG(2) << "GCCacheStore: Enter GC: object = "
735 << "(" << this << "), free recently cached = " << free_recent
736 << ", cache size = " << cache_size_
737 << ", cache frac = " << cache_fraction
738 << ", cache limit = " << cache_limit_ << "\n";
739 size_t cache_target = cache_fraction * cache_limit_;
741 while (!store_.Done()) {
742 auto *state = store_.GetMutableState(store_.Value());
743 if (cache_size_ > cache_target && state->RefCount() == 0 &&
744 (free_recent || !(state->Flags() & kCacheRecent)) && state != current) {
745 if (state->Flags() & kCacheInit) {
746 size_t size = sizeof(State) + state->NumArcs() * sizeof(Arc);
747 if (size < cache_size_) {
753 state->SetFlags(0, kCacheRecent);
757 if (!free_recent && cache_size_ > cache_target) { // Recurses on recent.
758 GC(current, true, cache_fraction);
759 } else if (cache_target > 0) { // Widens cache limit.
760 while (cache_size_ > cache_target) {
764 } else if (cache_size_ > 0) {
765 FSTERROR() << "GCCacheStore:GC: Unable to free all cached states";
767 VLOG(2) << "GCCacheStore: Exit GC: object = "
768 << "(" << this << "), free recently cached = " << free_recent
769 << ", cache size = " << cache_size_
770 << ", cache frac = " << cache_fraction
771 << ", cache limit = " << cache_limit_ << "\n";
774 template <class CacheStore>
775 constexpr size_t GCCacheStore<CacheStore>::kMinCacheLimit;
777 // This class is the default cache state and store used by CacheBaseImpl.
778 // It uses VectorCacheStore for storage decorated by FirstCacheStore
779 // and GCCacheStore to do (optional) garbage collection.
781 class DefaultCacheStore
782 : public GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>> {
784 explicit DefaultCacheStore(const CacheOptions &opts)
785 : GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>>(opts) {
791 // This class is used to cache FST elements stored in states of type State
792 // (see CacheState) with the flags used to indicate what has been cached. Use
793 // HasStart(), HasFinal(), and HasArcs() to determine if cached and SetStart(),
794 // SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note that you
795 // must set the final weight even if the state is non-final to mark it as
796 // cached. The state storage method and any garbage collection policy are
797 // determined by the cache store. If the store is passed in with the options,
798 // CacheBaseImpl takes ownership of it only if opts.own_store is true.
799 template <class State,
800 class CacheStore = DefaultCacheStore<typename State::Arc>>
801 class CacheBaseImpl : public FstImpl<typename State::Arc> {
803 using Arc = typename State::Arc;
804 using StateId = typename Arc::StateId;
805 using Weight = typename Arc::Weight;
807 using Store = CacheStore;
809 using FstImpl<Arc>::Type;
810 using FstImpl<Arc>::Properties;
812 explicit CacheBaseImpl(const CacheOptions &opts = CacheOptions())
814 cache_start_(kNoStateId),
816 min_unexpanded_state_id_(0),
817 max_expanded_state_id_(-1),
819 cache_limit_(opts.gc_limit),
820 cache_store_(new CacheStore(opts)),
821 new_cache_store_(true),
822 own_cache_store_(true) {}
824 explicit CacheBaseImpl(const CacheImplOptions<CacheStore> &opts)
826 cache_start_(kNoStateId),
828 min_unexpanded_state_id_(0),
829 max_expanded_state_id_(-1),
831 cache_limit_(opts.gc_limit),
832 cache_store_(opts.store ? opts.store : new CacheStore(CacheOptions(
833 opts.gc, opts.gc_limit))),
834 new_cache_store_(!opts.store),
835 own_cache_store_(opts.store ? opts.own_store : true) {}
837 // Preserve gc parameters. If preserve_cache is true, also preserves
839 CacheBaseImpl(const CacheBaseImpl<State, CacheStore> &impl,
840 bool preserve_cache = false)
843 cache_start_(kNoStateId),
845 min_unexpanded_state_id_(0),
846 max_expanded_state_id_(-1),
847 cache_gc_(impl.cache_gc_),
848 cache_limit_(impl.cache_limit_),
849 cache_store_(new CacheStore(CacheOptions(cache_gc_, cache_limit_))),
850 new_cache_store_(impl.new_cache_store_ || !preserve_cache),
851 own_cache_store_(true) {
852 if (preserve_cache) {
853 *cache_store_ = *impl.cache_store_;
854 has_start_ = impl.has_start_;
855 cache_start_ = impl.cache_start_;
856 nknown_states_ = impl.nknown_states_;
857 expanded_states_ = impl.expanded_states_;
858 min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
859 max_expanded_state_id_ = impl.max_expanded_state_id_;
863 ~CacheBaseImpl() override { if (own_cache_store_) delete cache_store_; }
865 void SetStart(StateId s) {
868 if (s >= nknown_states_) nknown_states_ = s + 1;
871 void SetFinal(StateId s, Weight weight) {
872 auto *state = cache_store_->GetMutableState(s);
873 state->SetFinal(std::move(weight));
874 static constexpr auto flags = kCacheFinal | kCacheRecent;
875 state->SetFlags(flags, flags);
878 // Disabled to ensure PushArc not AddArc is used in existing code
879 // TODO(sorenj): re-enable for backing store
881 // AddArc adds a single arc to a state and does incremental cache
882 // book-keeping. For efficiency, prefer PushArc and SetArcs below
884 void AddArc(StateId s, const Arc &arc) {
885 auto *state = cache_store_->GetMutableState(s);
886 cache_store_->AddArc(state, arc);
887 if (arc.nextstate >= nknown_states_)
888 nknown_states_ = arc.nextstate + 1;
890 static constexpr auto flags = kCacheArcs | kCacheRecent;
891 state->SetFlags(flags, flags);
895 // Adds a single arc to a state but delays cache book-keeping. SetArcs must
896 // be called when all PushArc calls at a state are complete. Do not mix with
898 void PushArc(StateId s, const Arc &arc) {
899 auto *state = cache_store_->GetMutableState(s);
903 // Marks arcs of a state as cached and does cache book-keeping after all
904 // calls to PushArc have been completed. Do not mix with calls to AddArc.
905 void SetArcs(StateId s) {
906 auto *state = cache_store_->GetMutableState(s);
907 cache_store_->SetArcs(state);
908 const auto narcs = state->NumArcs();
909 for (size_t a = 0; a < narcs; ++a) {
910 const auto &arc = state->GetArc(a);
911 if (arc.nextstate >= nknown_states_) nknown_states_ = arc.nextstate + 1;
914 static constexpr auto flags = kCacheArcs | kCacheRecent;
915 state->SetFlags(flags, flags);
918 void ReserveArcs(StateId s, size_t n) {
919 auto *state = cache_store_->GetMutableState(s);
920 state->ReserveArcs(n);
923 void DeleteArcs(StateId s) {
924 auto *state = cache_store_->GetMutableState(s);
925 cache_store_->DeleteArcs(state);
928 void DeleteArcs(StateId s, size_t n) {
929 auto *state = cache_store_->GetMutableState(s);
930 cache_store_->DeleteArcs(state, n);
935 min_unexpanded_state_id_ = 0;
936 max_expanded_state_id_ = -1;
938 cache_start_ = kNoStateId;
939 cache_store_->Clear();
942 // Is the start state cached?
943 bool HasStart() const {
944 if (!has_start_ && Properties(kError)) has_start_ = true;
948 // Is the final weight of the state cached?
949 bool HasFinal(StateId s) const {
950 const auto *state = cache_store_->GetState(s);
951 if (state && state->Flags() & kCacheFinal) {
952 state->SetFlags(kCacheRecent, kCacheRecent);
959 // Are arcs of the state cached?
960 bool HasArcs(StateId s) const {
961 const auto *state = cache_store_->GetState(s);
962 if (state && state->Flags() & kCacheArcs) {
963 state->SetFlags(kCacheRecent, kCacheRecent);
970 StateId Start() const { return cache_start_; }
972 Weight Final(StateId s) const {
973 const auto *state = cache_store_->GetState(s);
974 return state->Final();
977 size_t NumArcs(StateId s) const {
978 const auto *state = cache_store_->GetState(s);
979 return state->NumArcs();
982 size_t NumInputEpsilons(StateId s) const {
983 const auto *state = cache_store_->GetState(s);
984 return state->NumInputEpsilons();
987 size_t NumOutputEpsilons(StateId s) const {
988 const auto *state = cache_store_->GetState(s);
989 return state->NumOutputEpsilons();
992 // Provides information needed for generic arc iterator.
993 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
994 const auto *state = cache_store_->GetState(s);
995 data->base = nullptr;
996 data->narcs = state->NumArcs();
997 data->arcs = state->Arcs();
998 data->ref_count = state->MutableRefCount();
999 state->IncrRefCount();
1002 // Number of known states.
1003 StateId NumKnownStates() const { return nknown_states_; }
1005 // Updates number of known states, taking into account the passed state ID.
1006 void UpdateNumKnownStates(StateId s) {
1007 if (s >= nknown_states_) nknown_states_ = s + 1;
1010 // Finds the minimum never-expanded state ID.
1011 StateId MinUnexpandedState() const {
1012 while (min_unexpanded_state_id_ <= max_expanded_state_id_ &&
1013 ExpandedState(min_unexpanded_state_id_)) {
1014 ++min_unexpanded_state_id_;
1016 return min_unexpanded_state_id_;
1019 // Returns maximum ever-expanded state ID.
1020 StateId MaxExpandedState() const { return max_expanded_state_id_; }
1022 void SetExpandedState(StateId s) {
1023 if (s > max_expanded_state_id_) max_expanded_state_id_ = s;
1024 if (s < min_unexpanded_state_id_) return;
1025 if (s == min_unexpanded_state_id_) ++min_unexpanded_state_id_;
1026 if (cache_gc_ || cache_limit_ == 0) {
1027 while (expanded_states_.size() <= s) expanded_states_.push_back(false);
1028 expanded_states_[s] = true;
1032 bool ExpandedState(StateId s) const {
1033 if (cache_gc_ || cache_limit_ == 0) {
1034 return expanded_states_[s];
1035 } else if (new_cache_store_) {
1036 return cache_store_->GetState(s) != nullptr;
1038 // If the cache was not created by this class, then the cached state needs
1039 // to be inspected to update nknown_states_.
1044 const CacheStore *GetCacheStore() const { return cache_store_; }
1046 CacheStore *GetCacheStore() { return cache_store_; }
1048 // Caching on/off switch, limit and size accessors.
1050 bool GetCacheGc() const { return cache_gc_; }
1052 size_t GetCacheLimit() const { return cache_limit_; }
1055 mutable bool has_start_; // Is the start state cached?
1056 StateId cache_start_; // ID of start state.
1057 StateId nknown_states_; // Number of known states.
1058 std::vector<bool> expanded_states_; // States that have been expanded.
1059 mutable StateId min_unexpanded_state_id_; // Minimum never-expanded state ID
1060 mutable StateId max_expanded_state_id_; // Maximum ever-expanded state ID
1061 bool cache_gc_; // GC enabled.
1062 size_t cache_limit_; // Number of bytes allowed before GC.
1063 CacheStore *cache_store_; // The store of cached states.
1064 bool new_cache_store_; // Was the store was created by class?
1065 bool own_cache_store_; // Is the store owned by class?
1067 CacheBaseImpl &operator=(const CacheBaseImpl &impl) = delete;
1070 // A CacheBaseImpl with the default cache state type.
1071 template <class Arc>
1072 class CacheImpl : public CacheBaseImpl<CacheState<Arc>> {
1074 using State = CacheState<Arc>;
1078 explicit CacheImpl(const CacheOptions &opts)
1079 : CacheBaseImpl<CacheState<Arc>>(opts) {}
1081 CacheImpl(const CacheImpl<Arc> &impl, bool preserve_cache = false)
1082 : CacheBaseImpl<State>(impl, preserve_cache) {}
1085 CacheImpl &operator=(const CacheImpl &impl) = delete;
1088 } // namespace internal
1090 // Use this to make a state iterator for a CacheBaseImpl-derived FST, which must
1091 // have Arc and Store types defined. Note this iterator only returns those
1092 // states reachable from the initial state, so consider implementing a
1093 // class-specific one.
1095 // This class may be derived from.
1096 template <class FST>
1097 class CacheStateIterator : public StateIteratorBase<typename FST::Arc> {
1099 using Arc = typename FST::Arc;
1100 using StateId = typename Arc::StateId;
1101 using Weight = typename Arc::Weight;
1103 using Store = typename FST::Store;
1104 using State = typename Store::State;
1105 using Impl = internal::CacheBaseImpl<State, Store>;
1107 CacheStateIterator(const FST &fst, Impl *impl)
1108 : fst_(fst), impl_(impl), s_(0) {
1109 fst_.Start(); // Forces start state.
1112 bool Done() const final {
1113 if (s_ < impl_->NumKnownStates()) return false;
1114 for (StateId u = impl_->MinUnexpandedState(); u < impl_->NumKnownStates();
1115 u = impl_->MinUnexpandedState()) {
1116 // Forces state expansion.
1117 ArcIterator<FST> aiter(fst_, u);
1118 aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
1119 for (; !aiter.Done(); aiter.Next()) {
1120 impl_->UpdateNumKnownStates(aiter.Value().nextstate);
1122 impl_->SetExpandedState(u);
1123 if (s_ < impl_->NumKnownStates()) return false;
1128 StateId Value() const final { return s_; }
1130 void Next() final { ++s_; }
1132 void Reset() final { s_ = 0; }
1140 // Used to make an arc iterator for a CacheBaseImpl-derived FST, which must
1141 // have Arc and State types defined.
1142 template <class FST>
1143 class CacheArcIterator {
1145 using Arc = typename FST::Arc;
1146 using StateId = typename Arc::StateId;
1147 using Weight = typename Arc::Weight;
1149 using Store = typename FST::Store;
1150 using State = typename Store::State;
1151 using Impl = internal::CacheBaseImpl<State, Store>;
1153 CacheArcIterator(Impl *impl, StateId s) : i_(0) {
1154 state_ = impl->GetCacheStore()->GetMutableState(s);
1155 state_->IncrRefCount();
1158 ~CacheArcIterator() { state_->DecrRefCount(); }
1160 bool Done() const { return i_ >= state_->NumArcs(); }
1162 const Arc &Value() const { return state_->GetArc(i_); }
1164 void Next() { ++i_; }
1166 size_t Position() const { return i_; }
1168 void Reset() { i_ = 0; }
1170 void Seek(size_t a) { i_ = a; }
1172 constexpr uint32 Flags() const { return kArcValueFlags; }
1174 void SetFlags(uint32 flags, uint32 mask) {}
1177 const State *state_;
1180 CacheArcIterator(const CacheArcIterator &) = delete;
1181 CacheArcIterator &operator=(const CacheArcIterator &) = delete;
1184 // Use this to make a mutable arc iterator for a CacheBaseImpl-derived FST,
1185 // which must have types Arc and Store defined.
1186 template <class FST>
1187 class CacheMutableArcIterator
1188 : public MutableArcIteratorBase<typename FST::Arc> {
1190 using Arc = typename FST::Arc;
1191 using StateId = typename Arc::StateId;
1192 using Weight = typename Arc::Weight;
1194 using Store = typename FST::Store;
1195 using State = typename Store::State;
1196 using Impl = internal::CacheBaseImpl<State, Store>;
1198 // User must call MutateCheck() in the constructor.
1199 CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
1200 state_ = impl_->GetCacheStore()->GetMutableState(s_);
1201 state_->IncrRefCount();
1204 ~CacheMutableArcIterator() override { state_->DecrRefCount(); }
1206 bool Done() const final { return i_ >= state_->NumArcs(); }
1208 const Arc &Value() const final { return state_->GetArc(i_); }
1210 void Next() final { ++i_; }
1212 size_t Position() const final { return i_; }
1214 void Reset() final { i_ = 0; }
1216 void Seek(size_t a) final { i_ = a; }
1218 void SetValue(const Arc &arc) final { state_->SetArc(arc, i_); }
1220 uint32 Flags() const final { return kArcValueFlags; }
1222 void SetFlags(uint32, uint32) final {}
1230 CacheMutableArcIterator(const CacheMutableArcIterator &) = delete;
1231 CacheMutableArcIterator &operator=(const CacheMutableArcIterator &) = delete;
1234 // Wrap existing CacheStore implementation to use with ExpanderFst.
1235 template <class CacheStore>
1236 class ExpanderCacheStore {
1238 using State = typename CacheStore::State;
1239 using Arc = typename CacheStore::Arc;
1240 using StateId = typename Arc::StateId;
1241 using Weight = typename Arc::Weight;
1243 explicit ExpanderCacheStore(const CacheOptions &opts = CacheOptions())
1246 template <class Expander>
1247 State *FindOrExpand(Expander &expander, StateId s) { // NOLINT
1248 auto *state = store_.GetMutableState(s);
1249 if (state->Flags()) {
1250 state->SetFlags(kCacheRecent, kCacheRecent);
1252 StateBuilder builder(state);
1253 expander.Expand(s, &builder);
1254 state->SetFlags(kCacheFlags, kCacheFlags);
1255 store_.SetArcs(state);
1263 struct StateBuilder {
1266 explicit StateBuilder(State *state_) : state(state_) {}
1268 void AddArc(const Arc &arc) { state->PushArc(arc); }
1270 void SetFinal(Weight weight) { state->SetFinal(std::move(weight)); }
1276 #endif // FST_LIB_CACHE_H_