1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // An FST implementation that caches FST elements of a delayed computation.
6 #ifndef FST_LIB_CACHE_H_
7 #define FST_LIB_CACHE_H_
10 #include <unordered_map>
11 using std::unordered_map;
12 using std::unordered_multimap;
18 #include <fst/vector-fst.h>
// Command-line flags supplying the default GC settings used by CacheOptions.
21 DECLARE_bool(fst_default_cache_gc);
22 DECLARE_int64(fst_default_cache_gc_limit);
26 // Options for controlling caching behavior; higher level than CacheImplOptions.
28 bool gc; // Enables GC.
29 size_t gc_limit; // Number of bytes allowed before GC.
// Default arguments are evaluated at each call, so the current flag values are
// used every time options are default-constructed.
// NOTE(review): the gc_limit flag is declared int64 but stored as size_t; a
// negative flag value would wrap to a very large limit.
31 explicit CacheOptions(bool gc = FLAGS_fst_default_cache_gc,
32 size_t gc_limit = FLAGS_fst_default_cache_gc_limit)
33 : gc(gc), gc_limit(gc_limit) {}
36 // Options for controlling caching behavior, at a lower level than
37 // CacheOptions; templated on the cache store and allows passing the store.
38 template <class CacheStore>
39 struct CacheImplOptions {
40 bool gc; // Enables GC.
41 size_t gc_limit; // Number of bytes allowed before GC.
42 CacheStore *store; // Cache store.
43 bool own_store; // Should CacheImpl take ownership of the store?
// Both constructors set own_store to true. A caller that passes a store it
// wants to keep ownership of must set own_store to false afterwards.
45 explicit CacheImplOptions(bool gc = FLAGS_fst_default_cache_gc,
46 size_t gc_limit = FLAGS_fst_default_cache_gc_limit,
47 CacheStore *store = nullptr)
48 : gc(gc), gc_limit(gc_limit), store(store), own_store(true) {}
// Converts high-level CacheOptions; no store is supplied (store is nullptr).
50 explicit CacheImplOptions(const CacheOptions &opts)
51 : gc(opts.gc), gc_limit(opts.gc_limit), store(nullptr), own_store(true) {}
// Per-state status bits, kept in each cache state's flags and updated with
// SetFlags(flags, mask).
55 constexpr uint32 kCacheFinal = 0x0001; // Final weight has been cached.
56 constexpr uint32 kCacheArcs = 0x0002; // Arcs have been cached.
57 constexpr uint32 kCacheInit = 0x0004; // Initialized by GC.
58 constexpr uint32 kCacheRecent = 0x0008; // Visited since GC.
// Mask covering all of the status bits above.
59 constexpr uint32 kCacheFlags =
60 kCacheFinal | kCacheArcs | kCacheInit | kCacheRecent;
62 // Cache state, with arcs stored in a per-state std::vector.
// Besides the final weight and the arcs, a state keeps counts of input/output
// epsilon arcs (label 0), kCache* status flags and a reference count. flags_
// and ref_count_ are mutable so they can be updated through const pointers.
63 template <class A, class M = PoolAllocator<A>>
67 using Label = typename Arc::Label;
68 using StateId = typename Arc::StateId;
69 using Weight = typename Arc::Weight;
71 using ArcAllocator = M;
// Allocator type used to allocate the CacheState objects themselves.
72 using StateAllocator =
73 typename ArcAllocator::template rebind<CacheState<A, M>>::other;
75 // Provides STL allocator for arcs.
76 explicit CacheState(const ArcAllocator &alloc)
77 : final_(Weight::Zero()),
// Copies `state`, allocating the arc storage with `alloc`.
// NOTE(review): the parameter type is CacheState<A> (default allocator), so for
// a non-default M this is a converting constructor rather than a copy
// constructor of CacheState<A, M> -- confirm this is intended.
84 CacheState(const CacheState<A> &state, const ArcAllocator &alloc)
85 : final_(state.Final()),
86 niepsilons_(state.NumInputEpsilons()),
87 noepsilons_(state.NumOutputEpsilons()),
88 arcs_(state.arcs_.begin(), state.arcs_.end(), alloc),
89 flags_(state.Flags()),
93 final_ = Weight::Zero();
101 Weight Final() const { return final_; }
103 size_t NumInputEpsilons() const { return niepsilons_; }
105 size_t NumOutputEpsilons() const { return noepsilons_; }
107 size_t NumArcs() const { return arcs_.size(); }
// No bounds check: n must be less than NumArcs().
109 const Arc &GetArc(size_t n) const { return arcs_[n]; }
111 // Used by the ArcIterator<Fst<Arc>> efficient implementation.
// Returns the contiguous arc array, or nullptr when there are no arcs.
112 const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; }
114 // Accesses flags; used by the caller.
115 uint32 Flags() const { return flags_; }
117 // Accesses ref count; used by the caller.
118 int RefCount() const { return ref_count_; }
120 void SetFinal(Weight weight) { final_ = std::move(weight); }
122 void ReserveArcs(size_t n) { arcs_.reserve(n); }
124 // Adds one arc at a time with all needed book-keeping; use PushArc and
125 // SetArcs for a more efficient alternative.
126 void AddArc(const Arc &arc) {
127 arcs_.push_back(arc);
128 if (arc.ilabel == 0) ++niepsilons_;
129 if (arc.olabel == 0) ++noepsilons_;
132 // Adds one arc at a time with delayed book-keeping; finalize with SetArcs().
133 void PushArc(const Arc &arc) { arcs_.push_back(arc); }
135 // Finalizes arcs book-keeping; call only once.
// The epsilon counts are incremented over all arcs (not recomputed from zero),
// so a second call would double-count them.
137 for (const auto &arc : arcs_) {
138 if (arc.ilabel == 0) ++niepsilons_;
139 if (arc.olabel == 0) ++noepsilons_;
// Updates the epsilon counts for replacing arc n with `arc`; n must be less
// than NumArcs() (arcs_[n] is read without a bounds check).
144 void SetArc(const Arc &arc, size_t n) {
145 if (arcs_[n].ilabel == 0) --niepsilons_;
146 if (arcs_[n].olabel == 0) --noepsilons_;
147 if (arc.ilabel == 0) ++niepsilons_;
148 if (arc.olabel == 0) ++noepsilons_;
// Deletes the last n arcs, decrementing the epsilon counts for each one.
// n must not exceed NumArcs(): back() on an empty vector is undefined.
159 void DeleteArcs(size_t n) {
160 for (size_t i = 0; i < n; ++i) {
161 if (arcs_.back().ilabel == 0) --niepsilons_;
162 if (arcs_.back().olabel == 0) --noepsilons_;
167 // Sets status flags; used by the caller.
168 void SetFlags(uint32 flags, uint32 mask) const {
173 // Mutates reference counts; used by the caller.
175 int IncrRefCount() const { return ++ref_count_; }
177 int DecrRefCount() const { return --ref_count_; }
179 // Used by the ArcIterator<Fst<Arc>> efficient implementation.
180 int *MutableRefCount() const { return &ref_count_; }
182 // Used for state class allocation.
// `size` is ignored: exactly one object is allocated from `alloc`.
183 void *operator new(size_t size, StateAllocator *alloc) {
184 return alloc->allocate(1);
187 // For state destruction and memory freeing.
// Runs the destructor explicitly, then returns the storage to `alloc`.
// NOTE(review): the parameter is CacheState<Arc> (default allocator), which is
// a different type from CacheState<A, M> when M is not the default -- confirm.
188 static void Destroy(CacheState<Arc> *state, StateAllocator *alloc) {
190 state->~CacheState<Arc>();
191 alloc->deallocate(state, 1);
196 Weight final_; // Final weight.
197 size_t niepsilons_; // # of input epsilons.
198 size_t noepsilons_; // # of output epsilons.
199 std::vector<Arc, ArcAllocator> arcs_; // Arcs representation.
200 mutable uint32 flags_; // kCache* status bits.
201 mutable int ref_count_; // If 0, available for GC.
204 // Cache store, allocating and storing states, providing a mapping from state
205 // IDs to cached states, and an iterator over these states. The state template
206 // argument must implement the CacheState interface. The state for a StateId s
207 // is constructed when requested by GetMutableState(s) if it is not yet stored.
208 // Initially, a state has a reference count of zero, but the user may increment
209 // or decrement this to control the time of destruction. In particular, a state
210 // is destroyed when:
212 // 1. This instance is destroyed, or
213 // 2. Clear() or Delete() is called, or
214 // 3. Possibly (implementation-dependently) when:
215 // - Garbage collection is enabled (as defined by opts.gc),
216 // - The cache store size exceeds the limit (as defined by opts.gc_limit),
217 // - The state's reference count is zero, and
218 // - The state is not the most recently requested state.
220 // template <class S>
221 // class CacheStore {
224 // using Arc = typename State::Arc;
225 // using StateId = typename Arc::StateId;
227 // // Required constructors/assignment operators.
228 // explicit CacheStore(const CacheOptions &opts);
230 // // Returns nullptr if state is not stored.
231 // const State *GetState(StateId s);
233 // // Creates state if state is not stored.
234 // State *GetMutableState(StateId s);
236 // // Similar to State::AddArc() but updates cache store book-keeping.
237 // void AddArc(State *state, const Arc &arc);
239 // // Similar to State::SetArcs() but updates cache store book-keeping; call
241 // void SetArcs(State *state);
243 // // Similar to State::DeleteArcs() but updates cache store book-keeping.
245 // void DeleteArcs(State *state);
247 // void DeleteArcs(State *state, size_t n);
249 // // Deletes all cached states.
252 // // Iterates over cached states (in an arbitrary order); only needed if
253 // // opts.gc is true.
254 // bool Done() const; // End of iteration.
255 // StateId Value() const; // Current state.
256 // void Next(); // Advances to next state (when !Done).
257 // void Reset(); // Returns to initial condition.
258 // void Delete(); // Deletes current state and advances to next.
261 // Container cache stores.
263 // This class uses a vector of pointers to states to store cached states.
// State IDs index state_vec_ directly (null entries for uncached IDs). When
// cache_gc_ is set, every created state ID is also appended to state_list_,
// and the Done/Value/Next/Reset/Delete iteration walks that list.
265 class VectorCacheStore {
268 using Arc = typename State::Arc;
269 using StateId = typename Arc::StateId;
270 using StateList = std::list<StateId, PoolAllocator<StateId>>;
272 // Required constructors/assignment operators.
273 explicit VectorCacheStore(const CacheOptions &opts) : cache_gc_(opts.gc) {
278 VectorCacheStore(const VectorCacheStore<S> &store)
279 : cache_gc_(store.cache_gc_) {
284 ~VectorCacheStore() { Clear(); }
286 VectorCacheStore<State> &operator=(const VectorCacheStore<State> &store) {
287 if (this != &store) {
294 // Returns nullptr if state is not stored.
// s is compared against an unsigned size; if StateId is signed, a negative s
// converts to a very large value and also yields nullptr.
295 const State *GetState(StateId s) const {
296 return s < state_vec_.size() ? state_vec_[s] : nullptr;
299 // Creates state if state is not stored.
// Grows state_vec_ with null entries so that index s exists.
300 State *GetMutableState(StateId s) {
301 State *state = nullptr;
302 if (s >= state_vec_.size()) {
303 state_vec_.resize(s + 1, nullptr);
305 state = state_vec_[s];
308 state = new (&state_alloc_) State(arc_alloc_);
309 state_vec_[s] = state;
310 if (cache_gc_) state_list_.push_back(s);
315 // Similar to State::AddArc() but updates cache store book-keeping.
316 void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }
318 // Similar to State::SetArcs() but updates cache store book-keeping; call
320 void SetArcs(State *state) { state->SetArcs(); }
323 void DeleteArcs(State *state) { state->DeleteArcs(); }
325 // Deletes some arcs.
326 void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }
328 // Deletes all cached states.
330 for (StateId s = 0; s < state_vec_.size(); ++s) {
331 State::Destroy(state_vec_[s], &state_alloc_);
337 // Iterates over cached states (in an arbitrary order); only works if GC is
338 // enabled (o.w. avoiding state_list_ overhead).
339 bool Done() const { return iter_ == state_list_.end(); }
341 StateId Value() const { return *iter_; }
343 void Next() { ++iter_; }
345 void Reset() { iter_ = state_list_.begin(); }
347 // Deletes current state and advances to next.
349 State::Destroy(state_vec_[*iter_], &state_alloc_);
350 state_vec_[*iter_] = nullptr;
// erase(iter_++) advances iter_ before the erased node becomes invalid.
351 state_list_.erase(iter_++);
// Copies the states of `store` into this store, allocating them with this
// store's allocators. Entries are appended to state_vec_, so state IDs are
// preserved only if this store is empty on entry.
355 void CopyStates(const VectorCacheStore<State> &store) {
357 state_vec_.reserve(store.state_vec_.size());
358 for (StateId s = 0; s < store.state_vec_.size(); ++s) {
359 State *state = nullptr;
360 const auto *store_state = store.state_vec_[s];
362 state = new (&state_alloc_) State(*store_state, arc_alloc_);
363 if (cache_gc_) state_list_.push_back(s);
365 state_vec_.push_back(state);
369 bool cache_gc_; // Supports iteration when true.
370 std::vector<State *> state_vec_; // Vector of states (or null).
371 StateList state_list_; // List of states.
372 typename StateList::iterator iter_; // State list iterator.
373 typename State::StateAllocator state_alloc_; // For state allocation.
374 typename State::ArcAllocator arc_alloc_; // For arc allocation.
377 // This class uses a hash map from state IDs to pointers to cached states.
// Iteration (Done/Value/Next/Reset/Delete) walks state_map_ directly.
379 class HashCacheStore {
382 using Arc = typename State::Arc;
383 using StateId = typename Arc::StateId;
386 std::unordered_map<StateId, State *, std::hash<StateId>,
387 std::equal_to<StateId>,
388 PoolAllocator<std::pair<const StateId, State *>>>;
390 // Required constructors/assignment operators.
391 explicit HashCacheStore(const CacheOptions &opts) {
396 HashCacheStore(const HashCacheStore<S> &store) {
401 ~HashCacheStore() { Clear(); }
403 HashCacheStore<State> &operator=(const HashCacheStore<State> &store) {
404 if (this != &store) {
411 // Returns nullptr if state is not stored.
412 const State *GetState(StateId s) const {
413 const auto it = state_map_.find(s);
414 return it != state_map_.end() ? it->second : nullptr;
417 // Creates state if state is not stored.
418 State *GetMutableState(StateId s) {
// operator[] inserts a null entry for a new ID; the state is then allocated.
419 auto *&state = state_map_[s];
420 if (!state) state = new (&state_alloc_) State(arc_alloc_);
424 // Similar to State::AddArc() but updates cache store book-keeping.
425 void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }
427 // Similar to State::SetArcs() but updates internal cache size; call only
429 void SetArcs(State *state) { state->SetArcs(); }
432 void DeleteArcs(State *state) { state->DeleteArcs(); }
434 // Deletes some arcs.
435 void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }
437 // Deletes all cached states.
439 for (auto it = state_map_.begin(); it != state_map_.end(); ++it) {
440 State::Destroy(it->second, &state_alloc_);
445 // Iterates over cached states (in an arbitrary order).
446 bool Done() const { return iter_ == state_map_.end(); }
448 StateId Value() const { return iter_->first; }
450 void Next() { ++iter_; }
452 void Reset() { iter_ = state_map_.begin(); }
454 // Deletes current state and advances to next.
456 State::Destroy(iter_->second, &state_alloc_);
// erase(iter_++) advances iter_ before the erased element becomes invalid.
457 state_map_.erase(iter_++);
// Copies every state of `store`, allocating with this store's allocators.
// An existing entry for the same ID would be overwritten without being
// destroyed; presumably this is only called on an empty store.
461 void CopyStates(const HashCacheStore<State> &store) {
463 for (auto it = store.state_map_.begin(); it != store.state_map_.end();
465 state_map_[it->first] =
466 new (&state_alloc_) State(*it->second, arc_alloc_);
470 StateMap state_map_; // Map from state ID to state.
471 typename StateMap::iterator iter_; // State map iterator.
472 typename State::StateAllocator state_alloc_; // For state allocation.
473 typename State::ArcAllocator arc_alloc_; // For arc allocation.
476 // Garbage-collection cache stores.
478 // This class implements a simple garbage collection scheme when
479 // 'opts.gc_limit = 0'. In particular, the first cached state is reused for each
480 // new state so long as the reference count is zero on the to-be-reused state.
481 // Otherwise, the full underlying store is used. The caller can increment the
482 // reference count to inhibit the GC of in-use states (e.g., in an ArcIterator).
484 // The typical use case for this optimization is when a single pass over a
486 // FST is performed with only one state expanded at a time.
487 template <class CacheStore>
488 class FirstCacheStore {
490 using State = typename CacheStore::State;
491 using Arc = typename State::Arc;
492 using StateId = typename Arc::StateId;
494 // Required constructors/assignment operators.
495 explicit FirstCacheStore(const CacheOptions &opts)
497 cache_gc_(opts.gc_limit == 0), // opts.gc ignored historically.
498 cache_first_state_id_(kNoStateId),
499 cache_first_state_(nullptr) {}
// cache_first_state_ is re-derived from the copied store_ (slot 0) instead of
// copying the source pointer. This relies on store_ being declared before
// cache_first_state_, so that store_ is already constructed here.
501 FirstCacheStore(const FirstCacheStore<CacheStore> &store)
502 : store_(store.store_),
503 cache_gc_(store.cache_gc_),
504 cache_first_state_id_(store.cache_first_state_id_),
505 cache_first_state_(store.cache_first_state_id_ != kNoStateId
506 ? store_.GetMutableState(0)
509 FirstCacheStore<CacheStore> &operator=(
510 const FirstCacheStore<CacheStore> &store) {
511 if (this != &store) {
512 store_ = store.store_;
513 cache_gc_ = store.cache_gc_;
514 cache_first_state_id_ = store.cache_first_state_id_;
515 cache_first_state_ = store.cache_first_state_id_ != kNoStateId
516 ? store_.GetMutableState(0)
522 // Returns nullptr if state is not stored.
523 const State *GetState(StateId s) const {
524 // store_ state 0 may hold first cached state; the rest are shifted by 1.
525 return s == cache_first_state_id_ ? cache_first_state_
526 : store_.GetState(s + 1);
529 // Creates state if state is not stored.
530 State *GetMutableState(StateId s) {
531 // store_ state 0 used to hold first cached state; the rest are shifted by
533 if (cache_first_state_id_ == s) {
534 return cache_first_state_; // Request for first cached state.
537 if (cache_first_state_id_ == kNoStateId) {
538 cache_first_state_id_ = s; // Sets first cached state.
539 cache_first_state_ = store_.GetMutableState(0);
540 cache_first_state_->SetFlags(kCacheInit, kCacheInit);
541 cache_first_state_->ReserveArcs(2 * kAllocSize);
542 return cache_first_state_;
// Slot 0 can be reset and reused for the new ID only while nobody holds a
// reference to the state currently stored there.
543 } else if (cache_first_state_->RefCount() == 0) {
544 cache_first_state_id_ = s; // Updates first cached state.
545 cache_first_state_->Reset();
546 cache_first_state_->SetFlags(kCacheInit, kCacheInit);
547 return cache_first_state_;
548 } else { // Keeps first cached state.
549 cache_first_state_->SetFlags(0, kCacheInit); // Clears initialized bit.
550 cache_gc_ = false; // Disables GC.
553 auto *state = store_.GetMutableState(s + 1);
557 // Similar to State::AddArc() but updates cache store book-keeping.
558 void AddArc(State *state, const Arc &arc) { store_.AddArc(state, arc); }
560 // Similar to State::SetArcs() but updates internal cache size; call only
562 void SetArcs(State *state) { store_.SetArcs(state); }
565 void DeleteArcs(State *state) { store_.DeleteArcs(state); }
568 void DeleteArcs(State *state, size_t n) { store_.DeleteArcs(state, n); }
570 // Deletes all cached states.
573 cache_first_state_id_ = kNoStateId;
574 cache_first_state_ = nullptr;
577 // Iterates over cached states (in an arbitrary order). Only needed if GC is
579 bool Done() const { return store_.Done(); }
581 StateId Value() const {
582 // store_ state 0 may hold first cached state; rest shifted + 1.
583 const auto s = store_.Value();
584 return s ? s - 1 : cache_first_state_id_;
587 void Next() { store_.Next(); }
589 void Reset() { store_.Reset(); }
591 // Deletes current state and advances to next.
593 if (Value() == cache_first_state_id_) {
594 cache_first_state_id_ = kNoStateId;
595 cache_first_state_ = nullptr;
601 CacheStore store_; // Underlying store.
602 bool cache_gc_; // GC enabled.
603 StateId cache_first_state_id_; // First cached state ID.
604 State *cache_first_state_; // First cached state.
607 // This class implements mark-sweep garbage collection on an underlying cache
608 // store. If GC is enabled, garbage collection of states is performed in a
609 // rough approximation of LRU order once when 'gc_limit' bytes is reached. The
610 // caller can increment the reference count to inhibit the GC of in-use state
611 // (e.g., in an ArcIterator). With GC enabled, the 'gc_limit' parameter allows
612 // the caller to trade-off time vs. space.
613 template <class CacheStore>
616 using State = typename CacheStore::State;
617 using Arc = typename State::Arc;
618 using StateId = typename Arc::StateId;
620 // Required constructors/assignment operators.
// opts.gc_limit is used as the cache limit only if it exceeds kMinCacheLimit.
621 explicit GCCacheStore(const CacheOptions &opts)
623 cache_gc_request_(opts.gc),
624 cache_limit_(opts.gc_limit > kMinCacheLimit ? opts.gc_limit
629 // Returns nullptr if state is not stored.
630 const State *GetState(StateId s) const { return store_.GetState(s); }
632 // Creates state if state is not stored.
// On first sight of a state without kCacheInit, marks it and charges its size
// (state object plus current arcs) to cache_size_.
633 State *GetMutableState(StateId s) {
634 auto *state = store_.GetMutableState(s);
635 if (cache_gc_request_ && !(state->Flags() & kCacheInit)) {
636 state->SetFlags(kCacheInit, kCacheInit);
637 cache_size_ += sizeof(State) + state->NumArcs() * sizeof(Arc);
638 // GC is enabled once an uninited state (from underlying store) is seen.
640 if (cache_size_ > cache_limit_) GC(state, false);
645 // Similar to State::AddArc() but updates cache store book-keeping.
// Size accounting below only applies to states marked kCacheInit.
646 void AddArc(State *state, const Arc &arc) {
647 store_.AddArc(state, arc);
648 if (cache_gc_ && (state->Flags() & kCacheInit)) {
649 cache_size_ += sizeof(Arc);
650 if (cache_size_ > cache_limit_) GC(state, false);
654 // Similar to State::SetArcs() but updates internal cache size; call only
656 void SetArcs(State *state) {
657 store_.SetArcs(state);
658 if (cache_gc_ && (state->Flags() & kCacheInit)) {
659 cache_size_ += state->NumArcs() * sizeof(Arc);
660 if (cache_size_ > cache_limit_) GC(state, false);
// The size is subtracted before deleting, while NumArcs() still counts the
// arcs being removed.
665 void DeleteArcs(State *state) {
666 if (cache_gc_ && (state->Flags() & kCacheInit)) {
667 cache_size_ -= state->NumArcs() * sizeof(Arc);
669 store_.DeleteArcs(state);
672 // Deletes some arcs.
673 void DeleteArcs(State *state, size_t n) {
674 if (cache_gc_ && (state->Flags() & kCacheInit)) {
675 cache_size_ -= n * sizeof(Arc);
677 store_.DeleteArcs(state, n);
680 // Deletes all cached states.
686 // Iterates over cached states (in an arbitrary order); only needed if GC is
688 bool Done() const { return store_.Done(); }
690 StateId Value() const { return store_.Value(); }
692 void Next() { store_.Next(); }
694 void Reset() { store_.Reset(); }
696 // Deletes current state and advances to next.
699 const auto *state = store_.GetState(Value());
700 if (state->Flags() & kCacheInit) {
701 cache_size_ -= sizeof(State) + state->NumArcs() * sizeof(Arc);
707 // Removes from the cache store (not reference-counted and not the current)
708 // states that have not been accessed since the last GC until at most
709 // cache_fraction * cache_limit_ bytes are cached. If that fails to free
710 // enough, attempts to uncache recently visited states as well. If still
711 // unable to free enough memory, then widens cache_limit_.
712 void GC(const State *current, bool free_recent, float cache_fraction = 0.666);
714 // Returns the current cache size in bytes or 0 if GC is disabled.
715 size_t CacheSize() const { return cache_size_; }
717 // Returns the cache limit in bytes.
718 size_t CacheLimit() const { return cache_limit_; }
// NOTE(review): 8096 is close to 8192 (2^13) and may be a typo; changing it
// would alter GC behavior, so confirm before touching it.
721 static constexpr size_t kMinCacheLimit = 8096; // Minimum cache limit.
723 CacheStore store_; // Underlying store.
724 bool cache_gc_request_; // GC requested but possibly not yet enabled.
725 size_t cache_limit_; // Number of bytes allowed before GC.
726 bool cache_gc_; // GC enabled.
727 size_t cache_size_; // Number of bytes cached.
// One sweep over store_. Eligible states are removed until cache_size_ falls
// to cache_target. If that is not enough, a second sweep also considers
// recently visited states; after that, cache_limit_ is widened.
730 template <class CacheStore>
731 void GCCacheStore<CacheStore>::GC(const State *current, bool free_recent,
732 float cache_fraction) {
733 if (!cache_gc_) return;
734 VLOG(2) << "GCCacheStore: Enter GC: object = "
735 << "(" << this << "), free recently cached = " << free_recent
736 << ", cache size = " << cache_size_
737 << ", cache frac = " << cache_fraction
738 << ", cache limit = " << cache_limit_ << "\n";
// Target size after GC; the float product is truncated when stored as size_t.
739 size_t cache_target = cache_fraction * cache_limit_;
741 while (!store_.Done()) {
742 auto *state = store_.GetMutableState(store_.Value());
// A state may be removed only while still above target, if unreferenced,
// not `current`, and (unless free_recent) not visited since the last GC.
743 if (cache_size_ > cache_target && state->RefCount() == 0 &&
744 (free_recent || !(state->Flags() & kCacheRecent)) && state != current) {
745 if (state->Flags() & kCacheInit) {
746 size_t size = sizeof(State) + state->NumArcs() * sizeof(Arc);
747 CHECK_LE(size, cache_size_);
752 state->SetFlags(0, kCacheRecent);
756 if (!free_recent && cache_size_ > cache_target) { // Recurses on recent.
757 GC(current, true, cache_fraction);
758 } else if (cache_target > 0) { // Widens cache limit.
759 while (cache_size_ > cache_target) {
// With a zero target the limit cannot be widened; any bytes still cached are
// reported as an error.
763 } else if (cache_size_ > 0) {
764 FSTERROR() << "GCCacheStore:GC: Unable to free all cached states";
766 VLOG(2) << "GCCacheStore: Exit GC: object = "
767 << "(" << this << "), free recently cached = " << free_recent
768 << ", cache size = " << cache_size_
769 << ", cache frac = " << cache_fraction
770 << ", cache limit = " << cache_limit_ << "\n";
// Out-of-class definition of the static constexpr member. It is required
// before C++17 when kMinCacheLimit is ODR-used.
773 template <class CacheStore>
774 constexpr size_t GCCacheStore<CacheStore>::kMinCacheLimit;
776 // This class is the default cache state and store used by CacheBaseImpl.
777 // It uses VectorCacheStore for storage decorated by FirstCacheStore
778 // and GCCacheStore to do (optional) garbage collection.
// FirstCacheStore's single-state reuse is active when opts.gc_limit == 0.
// GCCacheStore's mark-sweep GC is requested by opts.gc.
780 class DefaultCacheStore
781 : public GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>> {
783 explicit DefaultCacheStore(const CacheOptions &opts)
784 : GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>>(opts) {
790 // This class is used to cache FST elements stored in states of type State
791 // (see CacheState) with the flags used to indicate what has been cached. Use
792 // HasStart(), HasFinal(), and HasArcs() to determine if cached and SetStart(),
793 // SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note that you
794 // must set the final weight even if the state is non-final to mark it as
795 // cached. The state storage method and any garbage collection policy are
796 // determined by the cache store. If the store is passed in with the options,
797 // CacheBaseImpl takes ownership only if opts.own_store is true.
798 template <class State,
799 class CacheStore = DefaultCacheStore<typename State::Arc>>
800 class CacheBaseImpl : public FstImpl<typename State::Arc> {
802 using Arc = typename State::Arc;
803 using StateId = typename Arc::StateId;
804 using Weight = typename Arc::Weight;
806 using Store = CacheStore;
808 using FstImpl<Arc>::Type;
809 using FstImpl<Arc>::Properties;
// Creates and owns a new store built from opts.
811 explicit CacheBaseImpl(const CacheOptions &opts = CacheOptions())
813 cache_start_(kNoStateId),
815 min_unexpanded_state_id_(0),
816 max_expanded_state_id_(-1),
818 cache_limit_(opts.gc_limit),
819 cache_store_(new CacheStore(opts)),
820 new_cache_store_(true),
821 own_cache_store_(true) {}
// Uses opts.store if given (owned only when opts.own_store is true);
// otherwise creates and owns a new store from opts.gc and opts.gc_limit.
823 explicit CacheBaseImpl(const CacheImplOptions<CacheStore> &opts)
825 cache_start_(kNoStateId),
827 min_unexpanded_state_id_(0),
828 max_expanded_state_id_(-1),
830 cache_limit_(opts.gc_limit),
831 cache_store_(opts.store ? opts.store : new CacheStore(CacheOptions(
832 opts.gc, opts.gc_limit))),
833 new_cache_store_(!opts.store),
834 own_cache_store_(opts.store ? opts.own_store : true) {}
836 // Preserve gc parameters. If preserve_cache is true, also preserves
// Always allocates a fresh store owned by this object. With preserve_cache,
// the source store's contents are copy-assigned into it. cache_store_ is built
// from cache_gc_ and cache_limit_, which are declared (and so initialized)
// before it.
838 CacheBaseImpl(const CacheBaseImpl<State, CacheStore> &impl,
839 bool preserve_cache = false)
842 cache_start_(kNoStateId),
844 min_unexpanded_state_id_(0),
845 max_expanded_state_id_(-1),
846 cache_gc_(impl.cache_gc_),
847 cache_limit_(impl.cache_limit_),
848 cache_store_(new CacheStore(CacheOptions(cache_gc_, cache_limit_))),
849 new_cache_store_(impl.new_cache_store_ || !preserve_cache),
850 own_cache_store_(true) {
851 if (preserve_cache) {
852 *cache_store_ = *impl.cache_store_;
853 has_start_ = impl.has_start_;
854 cache_start_ = impl.cache_start_;
855 nknown_states_ = impl.nknown_states_;
856 expanded_states_ = impl.expanded_states_;
857 min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
858 max_expanded_state_id_ = impl.max_expanded_state_id_;
862 ~CacheBaseImpl() override { if (own_cache_store_) delete cache_store_; }
864 void SetStart(StateId s) {
867 if (s >= nknown_states_) nknown_states_ = s + 1;
// Caches the final weight of s and marks it as cached and recently used.
870 void SetFinal(StateId s, Weight weight) {
871 auto *state = cache_store_->GetMutableState(s);
872 state->SetFinal(std::move(weight));
873 static constexpr auto flags = kCacheFinal | kCacheRecent;
874 state->SetFlags(flags, flags);
877 // Disabled to ensure PushArc not AddArc is used in existing code
878 // TODO(sorenj): re-enable for backing store
880 // AddArc adds a single arc to a state and does incremental cache
881 // book-keeping. For efficiency, prefer PushArc and SetArcs below
883 void AddArc(StateId s, const Arc &arc) {
884 auto *state = cache_store_->GetMutableState(s);
885 cache_store_->AddArc(state, arc);
886 if (arc.nextstate >= nknown_states_)
887 nknown_states_ = arc.nextstate + 1;
889 static constexpr auto flags = kCacheArcs | kCacheRecent;
890 state->SetFlags(flags, flags);
894 // Adds a single arc to a state but delays cache book-keeping. SetArcs must
895 // be called when all PushArc calls at a state are complete. Do not mix with
897 void PushArc(StateId s, const Arc &arc) {
898 auto *state = cache_store_->GetMutableState(s);
902 // Marks arcs of a state as cached and does cache book-keeping after all
903 // calls to PushArc have been completed. Do not mix with calls to AddArc.
904 void SetArcs(StateId s) {
905 auto *state = cache_store_->GetMutableState(s);
906 cache_store_->SetArcs(state);
// Extends nknown_states_ to cover every destination state of the new arcs.
907 const auto narcs = state->NumArcs();
908 for (size_t a = 0; a < narcs; ++a) {
909 const auto &arc = state->GetArc(a);
910 if (arc.nextstate >= nknown_states_) nknown_states_ = arc.nextstate + 1;
913 static constexpr auto flags = kCacheArcs | kCacheRecent;
914 state->SetFlags(flags, flags);
917 void ReserveArcs(StateId s, size_t n) {
918 auto *state = cache_store_->GetMutableState(s);
919 state->ReserveArcs(n);
922 void DeleteArcs(StateId s) {
923 auto *state = cache_store_->GetMutableState(s);
924 cache_store_->DeleteArcs(state);
927 void DeleteArcs(StateId s, size_t n) {
928 auto *state = cache_store_->GetMutableState(s);
929 cache_store_->DeleteArcs(state, n);
934 min_unexpanded_state_id_ = 0;
935 max_expanded_state_id_ = -1;
937 cache_start_ = kNoStateId;
938 cache_store_->Clear();
941 // Is the start state cached?
// An FST with the kError property is treated as having its start cached.
942 bool HasStart() const {
943 if (!has_start_ && Properties(kError)) has_start_ = true;
947 // Is the final weight of the state cached?
// A positive check also marks the state as recently used (kCacheRecent).
948 bool HasFinal(StateId s) const {
949 const auto *state = cache_store_->GetState(s);
950 if (state && state->Flags() & kCacheFinal) {
951 state->SetFlags(kCacheRecent, kCacheRecent);
958 // Are arcs of the state cached?
// A positive check also marks the state as recently used (kCacheRecent).
959 bool HasArcs(StateId s) const {
960 const auto *state = cache_store_->GetState(s);
961 if (state && state->Flags() & kCacheArcs) {
962 state->SetFlags(kCacheRecent, kCacheRecent);
969 StateId Start() const { return cache_start_; }
// The per-state accessors below dereference the stored state without a null
// check; presumably callers check HasFinal()/HasArcs() first.
971 Weight Final(StateId s) const {
972 const auto *state = cache_store_->GetState(s);
973 return state->Final();
976 size_t NumArcs(StateId s) const {
977 const auto *state = cache_store_->GetState(s);
978 return state->NumArcs();
981 size_t NumInputEpsilons(StateId s) const {
982 const auto *state = cache_store_->GetState(s);
983 return state->NumInputEpsilons();
986 size_t NumOutputEpsilons(StateId s) const {
987 const auto *state = cache_store_->GetState(s);
988 return state->NumOutputEpsilons();
991 // Provides information needed for generic arc iterator.
// Points data at the state's arc array and increments its reference count.
// The iterator presumably decrements it again through data->ref_count.
992 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
993 const auto *state = cache_store_->GetState(s);
994 data->base = nullptr;
995 data->narcs = state->NumArcs();
996 data->arcs = state->Arcs();
997 data->ref_count = state->MutableRefCount();
998 state->IncrRefCount();
1001 // Number of known states.
1002 StateId NumKnownStates() const { return nknown_states_; }
1004 // Updates number of known states, taking into account the passed state ID.
1005 void UpdateNumKnownStates(StateId s) {
1006 if (s >= nknown_states_) nknown_states_ = s + 1;
1009 // Finds the minimum never-expanded state ID.
1010 StateId MinUnexpandedState() const {
1011 while (min_unexpanded_state_id_ <= max_expanded_state_id_ &&
1012 ExpandedState(min_unexpanded_state_id_)) {
1013 ++min_unexpanded_state_id_;
1015 return min_unexpanded_state_id_;
1018 // Returns maximum ever-expanded state ID.
1019 StateId MaxExpandedState() const { return max_expanded_state_id_; }
// Records s as expanded. The explicit expanded_states_ bitmap is only kept
// when GC is on or the cache limit is zero.
1021 void SetExpandedState(StateId s) {
1022 if (s > max_expanded_state_id_) max_expanded_state_id_ = s;
1023 if (s < min_unexpanded_state_id_) return;
1024 if (s == min_unexpanded_state_id_) ++min_unexpanded_state_id_;
1025 if (cache_gc_ || cache_limit_ == 0) {
1026 while (expanded_states_.size() <= s) expanded_states_.push_back(false);
1027 expanded_states_[s] = true;
1031 bool ExpandedState(StateId s) const {
1032 if (cache_gc_ || cache_limit_ == 0) {
// NOTE(review): unchecked index; s must be < expanded_states_.size().
// MinUnexpandedState() only passes IDs <= max_expanded_state_id_.
1033 return expanded_states_[s];
1034 } else if (new_cache_store_) {
1035 return cache_store_->GetState(s) != nullptr;
1037 // If the cache was not created by this class, then the cached state needs
1038 // to be inspected to update nknown_states_.
1043 const CacheStore *GetCacheStore() const { return cache_store_; }
1045 CacheStore *GetCacheStore() { return cache_store_; }
1047 // Caching on/off switch, limit and size accessors.
1049 bool GetCacheGc() const { return cache_gc_; }
1051 size_t GetCacheLimit() const { return cache_limit_; }
1054 mutable bool has_start_; // Is the start state cached?
1055 StateId cache_start_; // ID of start state.
1056 StateId nknown_states_; // Number of known states.
1057 std::vector<bool> expanded_states_; // States that have been expanded.
1058 mutable StateId min_unexpanded_state_id_; // Minimum never-expanded state ID.
1059 mutable StateId max_expanded_state_id_; // Maximum ever-expanded state ID.
1060 bool cache_gc_; // GC enabled.
1061 size_t cache_limit_; // Number of bytes allowed before GC.
1062 CacheStore *cache_store_; // The store of cached states.
1063 bool new_cache_store_; // Was the store created by this class?
1064 bool own_cache_store_; // Is the store owned by class?
1066 CacheBaseImpl &operator=(const CacheBaseImpl &impl) = delete;
1069 // A CacheBaseImpl with the default cache state type.
// No store type is given, so CacheBaseImpl's default store type is used.
1070 template <class Arc>
1071 class CacheImpl : public CacheBaseImpl<CacheState<Arc>> {
1073 using State = CacheState<Arc>;
1077 explicit CacheImpl(const CacheOptions &opts)
1078 : CacheBaseImpl<CacheState<Arc>>(opts) {}
// Copy constructor; preserve_cache is forwarded to CacheBaseImpl's copy
// constructor.
1080 CacheImpl(const CacheImpl<Arc> &impl, bool preserve_cache = false)
1081 : CacheBaseImpl<State>(impl, preserve_cache) {}
// Copy assignment is disabled, matching CacheBaseImpl.
1084 CacheImpl &operator=(const CacheImpl &impl) = delete;
1087 } // namespace internal
1089 // Use this to make a state iterator for a CacheBaseImpl-derived FST, which must
1090 // have Arc and Store types defined. Note this iterator only returns those
1091 // states reachable from the initial state, so consider implementing a
1092 // class-specific one.
1094 // This class may be derived from.
// NOTE(review): the declarations of fst_, impl_ and s_, any access labels,
// and the class's closing `};` are not visible in this class body — confirm
// against upstream.
1095 template <class FST>
1096 class CacheStateIterator : public StateIteratorBase<typename FST::Arc> {
1098 using Arc = typename FST::Arc;
1099 using StateId = typename Arc::StateId;
1100 using Weight = typename Arc::Weight;
1102 using Store = typename FST::Store;
1103 using State = typename Store::State;
1104 using Impl = internal::CacheBaseImpl<State, Store>;
// Binds the iterator to `fst` and its cache `impl` with the cursor at state 0,
// then calls Start() so the start state is computed before iteration.
1106 CacheStateIterator(const FST &fst, Impl *impl)
1107 : fst_(fst), impl_(impl), s_(0) {
1108 fst_.Start(); // Forces start state.
// Returns false while the cursor s_ is below impl_->NumKnownStates().
// Once the cursor reaches that count, states are discovered lazily: the
// state reported by impl_->MinUnexpandedState() is visited with an
// ArcIterator, each arc's destination is passed to UpdateNumKnownStates, the
// state is marked expanded, and the count is re-checked. This repeats until
// the cursor is below the count again or MinUnexpandedState() is no longer
// below NumKnownStates().
// NOTE(review): the closing braces and any return statement after the loop
// are not visible here. If the loop exits without returning false, control
// would fall off the end of a non-void function (undefined behavior); a
// trailing `return true;` appears to be missing — confirm.
1111 bool Done() const final {
1112 if (s_ < impl_->NumKnownStates()) return false;
1113 for (StateId u = impl_->MinUnexpandedState(); u < impl_->NumKnownStates();
1114 u = impl_->MinUnexpandedState()) {
1115 // Forces state expansion.
1116 ArcIterator<FST> aiter(fst_, u);
// Requests arc values; the mask also covers kArcNoCache while the flags
// leave it unset — presumably clearing it so the arcs get cached (confirm
// ArcIterator::SetFlags' flags/mask semantics).
1117 aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
1118 for (; !aiter.Done(); aiter.Next()) {
1119 impl_->UpdateNumKnownStates(aiter.Value().nextstate);
1121 impl_->SetExpandedState(u);
1122 if (s_ < impl_->NumKnownStates()) return false;
// Returns the current state ID (the cursor itself).
1127 StateId Value() const final { return s_; }
// Advances the cursor to the next state ID.
1129 void Next() final { ++s_; }
// Rewinds the cursor to state 0.
1131 void Reset() final { s_ = 0; }
1139 // Used to make an arc iterator for a CacheBaseImpl-derived FST, which must
1140 // have Arc and Store types defined.
// NOTE(review): this class body shows no `public:`/`private:` labels, no
// declaration of the `i_` position member used throughout, and no closing
// `};`. As written, members default to private and `i_` is undeclared —
// confirm against upstream.
1141 template <class FST>
1142 class CacheArcIterator {
1144 using Arc = typename FST::Arc;
1145 using StateId = typename Arc::StateId;
1146 using Weight = typename Arc::Weight;
1148 using Store = typename FST::Store;
1149 using State = typename Store::State;
1150 using Impl = internal::CacheBaseImpl<State, Store>;
// Looks up state `s` in impl's cache store and increments its reference count
// for the iterator's lifetime (the destructor decrements it), presumably so
// the state is not reclaimed while its arcs are being read.
1152 CacheArcIterator(Impl *impl, StateId s) : i_(0) {
1153 state_ = impl->GetCacheStore()->GetMutableState(s);
1154 state_->IncrRefCount();
// Releases the reference taken in the constructor.
1157 ~CacheArcIterator() { state_->DecrRefCount(); }
// True once the position is at or past the state's arc count.
1159 bool Done() const { return i_ >= state_->NumArcs(); }
// The arc at the current position (no bounds check is done here).
1161 const Arc &Value() const { return state_->GetArc(i_); }
1163 void Next() { ++i_; }
1165 size_t Position() const { return i_; }
1167 void Reset() { i_ = 0; }
// Jumps directly to arc position `a`.
1169 void Seek(size_t a) { i_ = a; }
// Always reports kArcValueFlags; SetFlags ignores its arguments.
1171 constexpr uint32 Flags() const { return kArcValueFlags; }
1173 void SetFlags(uint32 flags, uint32 mask) {}
// The pinned state; held through a const pointer even though it was fetched
// with GetMutableState.
1176 const State *state_;
// Copying is disabled.
1179 CacheArcIterator(const CacheArcIterator &) = delete;
1180 CacheArcIterator &operator=(const CacheArcIterator &) = delete;
1183 // Use this to make a mutable arc iterator for a CacheBaseImpl-derived FST,
1184 // which must have types Arc and Store defined.
// NOTE(review): no access labels, no member declarations (i_, s_, impl_,
// state_) and no closing `};` are visible in this class body — confirm
// against upstream.
1185 template <class FST>
1186 class CacheMutableArcIterator
1187 : public MutableArcIteratorBase<typename FST::Arc> {
1189 using Arc = typename FST::Arc;
1190 using StateId = typename Arc::StateId;
1191 using Weight = typename Arc::Weight;
1193 using Store = typename FST::Store;
1194 using State = typename Store::State;
1195 using Impl = internal::CacheBaseImpl<State, Store>;
1197 // User must call MutateCheck() in the constructor.
// (i.e. the code constructing this iterator must call MutateCheck() from its
// own constructor; this constructor does not call it.)
// Stores `s` and `impl`, fetches the mutable cached state for `s`, and
// increments its reference count; the destructor decrements it.
1198 CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
1199 state_ = impl_->GetCacheStore()->GetMutableState(s_);
1200 state_->IncrRefCount();
// Releases the reference taken in the constructor.
1203 ~CacheMutableArcIterator() override { state_->DecrRefCount(); }
// True once the position is at or past the state's arc count.
1205 bool Done() const final { return i_ >= state_->NumArcs(); }
// The arc at the current position (no bounds check is done here).
1207 const Arc &Value() const final { return state_->GetArc(i_); }
1209 void Next() final { ++i_; }
1211 size_t Position() const final { return i_; }
1213 void Reset() final { i_ = 0; }
// Jumps directly to arc position `a`.
1215 void Seek(size_t a) final { i_ = a; }
// Overwrites the arc at the current position with `arc` via State::SetArc.
1217 void SetValue(const Arc &arc) final { state_->SetArc(arc, i_); }
// Always reports kArcValueFlags; SetFlags ignores its arguments.
1219 uint32 Flags() const final { return kArcValueFlags; }
1221 void SetFlags(uint32, uint32) final {}
// Copying is disabled.
1229 CacheMutableArcIterator(const CacheMutableArcIterator &) = delete;
1230 CacheMutableArcIterator &operator=(const CacheMutableArcIterator &) = delete;
1233 // Wraps an existing CacheStore implementation for use with ExpanderFst.
1234 template <class CacheStore>
1235 class ExpanderCacheStore {
1237 using State = typename CacheStore::State;
1238 using Arc = typename CacheStore::Arc;
1239 using StateId = typename Arc::StateId;
1240 using Weight = typename Arc::Weight;
1242 explicit ExpanderCacheStore(const CacheOptions &opts = CacheOptions())
// NOTE(review): the constructor's initializer list/body is not visible here.
// Fetches the mutable cached state for `s` from store_. A state with any flag
// bit set is treated as already expanded and is only marked kCacheRecent.
// Otherwise `expander.Expand(s, &builder)` fills it through a StateBuilder,
// all kCacheFlags bits are set, and store_.SetArcs(state) is called.
// Presumably returns `state` in both cases.
// NOTE(review): the `} else {` separating the two paths, the return statement
// and the closing braces are not visible here — confirm against upstream.
1245 template <class Expander>
1246 State *FindOrExpand(Expander &expander, StateId s) { // NOLINT
1247 auto *state = store_.GetMutableState(s);
1248 if (state->Flags()) {
1249 state->SetFlags(kCacheRecent, kCacheRecent);
// Expansion path: populate the state via the expander.
1251 StateBuilder builder(state);
1252 expander.Expand(s, &builder);
1253 state->SetFlags(kCacheFlags, kCacheFlags);
1254 store_.SetArcs(state);
// Adapter handed to Expander::Expand: forwards AddArc to State::PushArc and
// SetFinal to State::SetFinal on the wrapped state.
// NOTE(review): the declaration of the `state` member is not visible here.
1262 struct StateBuilder {
1265 explicit StateBuilder(State *state_) : state(state_) {}
1267 void AddArc(const Arc &arc) { state->PushArc(arc); }
1269 void SetFinal(Weight weight) { state->SetFinal(std::move(weight)); }
1275 #endif // FST_LIB_CACHE_H_