1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // NGramFst implements an n-gram language model based upon the LOUDS data
5 // structure. Please refer to "Unary Data Structures for Language Models"
6 // http://research.google.com/pubs/archive/37218.pdf
8 #ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
9 #define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
19 #include <fst/compat.h>
22 #include <fst/extensions/ngram/bitmap-index.h>
23 #include <fst/fstlib.h>
24 #include <fst/mapped-file.h>
30 class NGramFstMatcher;
32 // Instance data containing mutable state for bookkeeping repeated access to
36 typedef typename A::Label Label;
37 typedef typename A::StateId StateId;
38 typedef typename A::Weight Weight;
44 std::vector<Label> context_;
45 StateId context_state_;
48 node_state_(kNoStateId),
49 context_state_(kNoStateId) {}
54 // Implementation class for the LOUDS-based NGramFst interface.
56 class NGramFstImpl : public FstImpl<A> {
57 using FstImpl<A>::SetInputSymbols;
58 using FstImpl<A>::SetOutputSymbols;
59 using FstImpl<A>::SetType;
60 using FstImpl<A>::WriteHeader;
62 friend class ArcIterator<NGramFst<A>>;
63 friend class NGramFstMatcher<A>;
66 using FstImpl<A>::InputSymbols;
67 using FstImpl<A>::SetProperties;
68 using FstImpl<A>::Properties;
71 typedef typename A::Label Label;
72 typedef typename A::StateId StateId;
73 typedef typename A::Weight Weight;
77 SetInputSymbols(nullptr);
78 SetOutputSymbols(nullptr);
79 SetProperties(kStaticProperties);
82 NGramFstImpl(const Fst<A> &fst, std::vector<StateId> *order_out);
84 explicit NGramFstImpl(const Fst<A> &fst) : NGramFstImpl(fst, nullptr) {}
86 NGramFstImpl(const NGramFstImpl &other) {
87 FSTERROR() << "Copying NGramFst Impls is not supported, use safe = false.";
88 SetProperties(kError, kError);
91 ~NGramFstImpl() override {
97 static NGramFstImpl<A> *Read(std::istream &strm, // NOLINT
98 const FstReadOptions &opts) {
99 NGramFstImpl<A> *impl = new NGramFstImpl();
101 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0;
102 uint64 num_states, num_futures, num_final;
103 const size_t offset =
104 sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
105 // Peek at num_states, num_futures and num_final to see how much more needs to be read.
106 strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
107 strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
108 strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
109 size_t size = Storage(num_states, num_futures, num_final);
110 MappedFile *data_region = MappedFile::Allocate(size);
111 char *data = reinterpret_cast<char *>(data_region->mutable_data());
112 // Copy num_states, num_futures and num_final back into data.
113 memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
114 memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
115 sizeof(num_futures));
116 memcpy(data + sizeof(num_states) + sizeof(num_futures),
117 reinterpret_cast<char *>(&num_final), sizeof(num_final));
118 strm.read(data + offset, size - offset);
123 impl->Init(data, false, data_region);
127 bool Write(std::ostream &strm, // NOLINT
128 const FstWriteOptions &opts) const {
130 hdr.SetStart(Start());
131 hdr.SetNumStates(num_states_);
132 WriteHeader(strm, opts, kFileVersion, &hdr);
133 strm.write(data_, StorageSize());
137 StateId Start() const { return start_; }
139 Weight Final(StateId state) const {
140 if (final_index_.Get(state)) {
141 return final_probs_[final_index_.Rank1(state)];
143 return Weight::Zero();
147 size_t NumArcs(StateId state, NGramFstInst<A> *inst = nullptr) const {
148 if (inst == nullptr) {
149 const std::pair<size_t, size_t> zeros =
150 (state == 0) ? select_root_ : future_index_.Select0s(state);
151 return zeros.second - zeros.first - 1;
153 SetInstFuture(state, inst);
154 return inst->num_futures_ + ((state == 0) ? 0 : 1);
157 size_t NumInputEpsilons(StateId state) const {
158 // State 0 has no parent, thus no backoff.
159 if (state == 0) return 0;
163 size_t NumOutputEpsilons(StateId state) const {
164 return NumInputEpsilons(state);
167 StateId NumStates() const { return num_states_; }
169 void InitStateIterator(StateIteratorData<A> *data) const {
171 data->nstates = num_states_;
174 static size_t Storage(uint64 num_states, uint64 num_futures,
180 sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
182 sizeof(b64) * (BitmapIndex::StorageSize(num_states * 2 + 1) +
183 BitmapIndex::StorageSize(num_futures + num_states + 1) +
184 BitmapIndex::StorageSize(num_states));
185 offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
186 // Pad for alignment, see
187 // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
188 offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
189 offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
190 (num_futures + 1) * sizeof(weight);
194 void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
195 if (inst->state_ != state) {
196 inst->state_ = state;
197 const std::pair<size_t, size_t> zeros = future_index_.Select0s(state);
198 inst->num_futures_ = zeros.second - zeros.first - 1;
199 inst->offset_ = future_index_.Rank1(zeros.first + 1);
203 void SetInstNode(NGramFstInst<A> *inst) const {
204 if (inst->node_state_ != inst->state_) {
205 inst->node_state_ = inst->state_;
206 inst->node_ = context_index_.Select1(inst->state_);
210 void SetInstContext(NGramFstInst<A> *inst) const {
212 if (inst->context_state_ != inst->state_) {
213 inst->context_state_ = inst->state_;
214 inst->context_.clear();
215 size_t node = inst->node_;
217 inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
218 node = context_index_.Select1(context_index_.Rank0(node) - 1);
223 // Access to the underlying representation
224 const char *GetData(size_t *data_size) const {
225 *data_size = StorageSize();
229 void Init(const char *data, bool owned, MappedFile *file = nullptr);
231 const std::vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
232 SetInstFuture(s, inst);
233 SetInstContext(inst);
234 return inst->context_;
237 size_t StorageSize() const {
238 return Storage(num_states_, num_futures_, num_final_);
241 void GetStates(const std::vector<Label> &context,
242 std::vector<StateId> *states) const;
245 StateId Transition(const std::vector<Label> &context, Label future) const;
247 // Properties always true for this Fst class.
248 static const uint64 kStaticProperties =
249 kAcceptor | kIDeterministic | kODeterministic | kEpsilons | kIEpsilons |
250 kOEpsilons | kILabelSorted | kOLabelSorted | kWeighted | kCyclic |
251 kInitialAcyclic | kNotTopSorted | kAccessible | kCoAccessible |
252 kNotString | kExpanded;
253 // Current file format version.
254 static const int kFileVersion = 4;
255 // Minimum file format version supported.
256 static const int kMinFileVersion = 4;
258 std::unique_ptr<MappedFile> data_region_;
259 const char *data_ = nullptr;
260 bool owned_ = false; // True if we own data_
261 StateId start_ = fst::kNoStateId;
262 uint64 num_states_ = 0;
263 uint64 num_futures_ = 0;
264 uint64 num_final_ = 0;
265 std::pair<size_t, size_t> select_root_;
266 const Label *root_children_ = nullptr;
267 // borrowed references
268 const uint64 *context_ = nullptr;
269 const uint64 *future_ = nullptr;
270 const uint64 *final_ = nullptr;
271 const Label *context_words_ = nullptr;
272 const Label *future_words_ = nullptr;
273 const Weight *backoff_ = nullptr;
274 const Weight *final_probs_ = nullptr;
275 const Weight *future_probs_ = nullptr;
276 BitmapIndex context_index_;
277 BitmapIndex future_index_;
278 BitmapIndex final_index_;
281 template <typename A>
282 inline void NGramFstImpl<A>::GetStates(
283 const std::vector<Label> &context,
284 std::vector<typename A::StateId> *states) const {
286 states->push_back(0);
287 typename std::vector<Label>::const_reverse_iterator cit = context.rbegin();
288 const Label *children = root_children_;
289 size_t num_children = select_root_.second - 2;
290 const Label *loc = std::lower_bound(children, children + num_children, *cit);
291 if (loc == children + num_children || *loc != *cit) return;
292 size_t node = 2 + loc - children;
293 states->push_back(context_index_.Rank1(node));
294 if (context.size() == 1) return;
295 size_t node_rank = context_index_.Rank1(node);
296 std::pair<size_t, size_t> zeros =
297 node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
298 size_t first_child = zeros.first + 1;
300 if (context_index_.Get(first_child) != false) {
301 size_t last_child = zeros.second - 1;
302 while (cit != context.rend()) {
303 children = context_words_ + context_index_.Rank1(first_child);
304 loc = std::lower_bound(children, children + last_child - first_child + 1,
306 if (loc == children + last_child - first_child + 1 || *loc != *cit) {
310 node = first_child + loc - children;
311 states->push_back(context_index_.Rank1(node));
312 node_rank = context_index_.Rank1(node);
314 node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
315 first_child = zeros.first + 1;
316 if (context_index_.Get(first_child) == false) break;
317 last_child = zeros.second - 1;
322 } // namespace internal
324 /*****************************************************************************/
326 class NGramFst : public ImplToExpandedFst<internal::NGramFstImpl<A>> {
327 friend class ArcIterator<NGramFst<A>>;
328 friend class NGramFstMatcher<A>;
332 typedef typename A::StateId StateId;
333 typedef typename A::Label Label;
334 typedef typename A::Weight Weight;
335 typedef internal::NGramFstImpl<A> Impl;
337 explicit NGramFst(const Fst<A> &dst)
338 : ImplToExpandedFst<Impl>(std::make_shared<Impl>(dst, nullptr)) {}
340 NGramFst(const Fst<A> &fst, std::vector<StateId> *order_out)
341 : ImplToExpandedFst<Impl>(std::make_shared<Impl>(fst, order_out)) {}
343 // Because the NGramFstImpl is a const stateless data structure, there
344 // is never a need to do anything besides copying the reference.
345 NGramFst(const NGramFst<A> &fst, bool safe = false)
346 : ImplToExpandedFst<Impl>(fst, false) {}
348 NGramFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {}
350 // Non-standard constructor to initialize NGramFst directly from data.
351 NGramFst(const char *data, bool owned)
352 : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {
353 GetMutableImpl()->Init(data, owned, nullptr);
356 // Gets the data associated with Init(), storing its size in *data_size.
357 const char *GetData(size_t *data_size) const {
358 return GetImpl()->GetData(data_size);
361 const std::vector<Label> GetContext(StateId s) const {
362 return GetImpl()->GetContext(s, &inst_);
365 // Consumes as much as possible of context from right to left, returns
366 // the states corresponding to the increasingly conditioned input sequence.
367 void GetStates(const std::vector<Label> &context,
368 std::vector<StateId> *state) const {
369 return GetImpl()->GetStates(context, state);
372 size_t NumArcs(StateId s) const override {
373 return GetImpl()->NumArcs(s, &inst_);
376 NGramFst<A> *Copy(bool safe = false) const override {
377 return new NGramFst(*this, safe);
380 static NGramFst<A> *Read(std::istream &strm, const FstReadOptions &opts) {
381 Impl *impl = Impl::Read(strm, opts);
382 return impl ? new NGramFst<A>(std::shared_ptr<Impl>(impl)) : nullptr;
385 static NGramFst<A> *Read(const string &filename) {
386 if (!filename.empty()) {
387 std::ifstream strm(filename,
388 std::ios_base::in | std::ios_base::binary);
390 LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename;
393 return Read(strm, FstReadOptions(filename));
395 return Read(std::cin, FstReadOptions("standard input"));
399 bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
400 return GetImpl()->Write(strm, opts);
403 bool Write(const string &filename) const override {
404 return Fst<A>::WriteFile(filename);
407 inline void InitStateIterator(StateIteratorData<A> *data) const override {
408 GetImpl()->InitStateIterator(data);
411 inline void InitArcIterator(StateId s,
412 ArcIteratorData<A> *data) const override;
414 MatcherBase<A> *InitMatcher(MatchType match_type) const override {
415 return new NGramFstMatcher<A>(*this, match_type);
418 size_t StorageSize() const { return GetImpl()->StorageSize(); }
420 static bool HasRequiredProps(const Fst<A> &fst) {
422 kAcceptor | kIDeterministic | kILabelSorted | kIEpsilons | kAccessible;
423 return fst.Properties(props, true) == props;
426 static bool HasRequiredStructure(const Fst<A> &fst) {
427 if (!HasRequiredProps(fst)) {
430 typename A::StateId unigram = fst.Start();
431 while (true) { // Follows epsilon arc chain to find unigram state.
432 if (unigram == fst::kNoStateId) return false; // No unigram state.
433 typename fst::ArcIterator<Fst<A>> aiter(fst, unigram);
434 if (aiter.Done() || aiter.Value().ilabel != 0) break;
435 unigram = aiter.Value().nextstate;
438 // Other requirement: all states other than unigram start with an epsilon arc.
439 for (fst::StateIterator<Fst<A>> siter(fst); !siter.Done();
441 const typename A::StateId &state = siter.Value();
442 fst::ArcIterator<Fst<A>> aiter(fst, state);
443 if (state != unigram) {
444 if (aiter.Done()) return false;
445 if (aiter.Value().ilabel != 0) return false;
447 if (!aiter.Done() && aiter.Value().ilabel == 0) return false;
454 using ImplToExpandedFst<Impl, ExpandedFst<A>>::GetImpl;
455 using ImplToExpandedFst<Impl, ExpandedFst<A>>::GetMutableImpl;
457 explicit NGramFst(std::shared_ptr<Impl> impl)
458 : ImplToExpandedFst<Impl>(impl) {}
460 mutable NGramFstInst<A> inst_;
464 inline void NGramFst<A>::InitArcIterator(StateId s,
465 ArcIteratorData<A> *data) const {
466 GetImpl()->SetInstFuture(s, &inst_);
467 GetImpl()->SetInstNode(&inst_);
468 data->base = new ArcIterator<NGramFst<A>>(*this, s);
473 template <typename A>
474 NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst,
475 std::vector<StateId> *order_out) {
477 typedef typename Arc::Label Label;
478 typedef typename Arc::Weight Weight;
479 typedef typename Arc::StateId StateId;
481 SetInputSymbols(fst.InputSymbols());
482 SetOutputSymbols(fst.OutputSymbols());
483 SetProperties(kStaticProperties);
485 // Check basic requirements for an OpenGrm language model Fst.
486 if (!NGramFst<A>::HasRequiredProps(fst)) {
487 FSTERROR() << "NGramFst only accepts OpenGrm language models as input";
488 SetProperties(kError, kError);
492 int64 num_states = CountStates(fst);
493 Label *context = new Label[num_states];
495 // Find the unigram state by starting from the start state, following
497 StateId unigram = fst.Start();
499 if (unigram == kNoStateId) {
500 FSTERROR() << "Could not identify unigram state";
501 SetProperties(kError, kError);
504 ArcIterator<Fst<A>> aiter(fst, unigram);
506 LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
509 if (aiter.Value().ilabel != 0) break;
510 unigram = aiter.Value().nextstate;
513 // Each state's context is determined by the subtree it is under from the
515 std::queue<std::pair<StateId, Label>> label_queue;
516 std::vector<bool> visited(num_states);
517 // Force an epsilon link to the start state.
518 label_queue.push(std::make_pair(fst.Start(), 0));
519 for (ArcIterator<Fst<A>> aiter(fst, unigram); !aiter.Done(); aiter.Next()) {
521 std::make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
523 // Investigate states in breadth-first fashion to assign context words.
524 while (!label_queue.empty()) {
525 std::pair<StateId, Label> &now = label_queue.front();
526 if (!visited[now.first]) {
527 context[now.first] = now.second;
528 visited[now.first] = true;
529 for (ArcIterator<Fst<A>> aiter(fst, now.first); !aiter.Done();
531 const Arc &arc = aiter.Value();
532 if (arc.ilabel != 0) {
533 label_queue.push(std::make_pair(arc.nextstate, now.second));
541 // The arc from the start state should be assigned an epsilon to put it
542 // in front of all other labels (which makes the start state 1, after the
543 // unigram state, which is state 0).
544 context[fst.Start()] = 0;
546 // Build the tree of contexts fst by reversing the epsilon arcs from fst.
547 VectorFst<Arc> context_fst;
548 uint64 num_final = 0;
549 for (int i = 0; i < num_states; ++i) {
550 if (fst.Final(i) != Weight::Zero()) {
553 context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
555 context_fst.SetStart(unigram);
556 context_fst.SetInputSymbols(fst.InputSymbols());
557 context_fst.SetOutputSymbols(fst.OutputSymbols());
558 int64 num_context_arcs = 0;
559 int64 num_futures = 0;
560 for (StateIterator<Fst<A>> siter(fst); !siter.Done(); siter.Next()) {
561 const StateId &state = siter.Value();
562 num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
563 ArcIterator<Fst<A>> aiter(fst, state);
565 const Arc &arc = aiter.Value();
566 // this arc goes from state to arc.nextstate, so create an arc from
567 // arc.nextstate to state to reverse it.
568 if (arc.ilabel == 0) {
569 context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
575 if (num_context_arcs != context_fst.NumStates() - 1) {
576 FSTERROR() << "Number of contexts arcs != number of states - 1";
577 SetProperties(kError, kError);
580 if (context_fst.NumStates() != num_states) {
581 FSTERROR() << "Number of contexts != number of states";
582 SetProperties(kError, kError);
585 int64 context_props =
586 context_fst.Properties(kIDeterministic | kILabelSorted, true);
587 if (!(context_props & kIDeterministic)) {
588 FSTERROR() << "Input Fst is not structured properly";
589 SetProperties(kError, kError);
592 if (!(context_props & kILabelSorted)) {
593 ArcSort(&context_fst, ILabelCompare<Arc>());
600 Label label = kNoLabel;
601 const size_t storage = Storage(num_states, num_futures, num_final);
602 MappedFile *data_region = MappedFile::Allocate(storage);
603 char *data = reinterpret_cast<char *>(data_region->mutable_data());
604 memset(data, 0, storage);
606 memcpy(data + offset, reinterpret_cast<char *>(&num_states),
608 offset += sizeof(num_states);
609 memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
610 sizeof(num_futures));
611 offset += sizeof(num_futures);
612 memcpy(data + offset, reinterpret_cast<char *>(&num_final),
614 offset += sizeof(num_final);
615 uint64 *context_bits = reinterpret_cast<uint64 *>(data + offset);
616 offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
617 uint64 *future_bits = reinterpret_cast<uint64 *>(data + offset);
619 BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
620 uint64 *final_bits = reinterpret_cast<uint64 *>(data + offset);
621 offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
622 Label *context_words = reinterpret_cast<Label *>(data + offset);
623 offset += (num_states + 1) * sizeof(label);
624 Label *future_words = reinterpret_cast<Label *>(data + offset);
625 offset += num_futures * sizeof(label);
626 offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
627 Weight *backoff = reinterpret_cast<Weight *>(data + offset);
628 offset += (num_states + 1) * sizeof(weight);
629 Weight *final_probs = reinterpret_cast<Weight *>(data + offset);
630 offset += num_final * sizeof(weight);
631 Weight *future_probs = reinterpret_cast<Weight *>(data + offset);
632 int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
636 BitmapIndex::Set(context_bits, context_bit++);
638 context_words[context_arc] = label;
639 backoff[context_arc] = Weight::Zero();
645 order_out->resize(num_states);
648 std::queue<StateId> context_q;
649 context_q.push(context_fst.Start());
650 StateId state_number = 0;
651 while (!context_q.empty()) {
652 const StateId &state = context_q.front();
654 (*order_out)[state] = state_number;
657 const Weight final_weight = context_fst.Final(state);
658 if (final_weight != Weight::Zero()) {
659 BitmapIndex::Set(final_bits, state_number);
660 final_probs[final_bit] = final_weight;
664 for (ArcIterator<VectorFst<A>> aiter(context_fst, state); !aiter.Done();
666 const Arc &arc = aiter.Value();
667 context_words[context_arc] = arc.ilabel;
668 backoff[context_arc] = arc.weight;
670 BitmapIndex::Set(context_bits, context_bit++);
671 context_q.push(arc.nextstate);
675 for (ArcIterator<Fst<A>> aiter(fst, state); !aiter.Done(); aiter.Next()) {
676 const Arc &arc = aiter.Value();
677 if (arc.ilabel != 0) {
678 future_words[future_arc] = arc.ilabel;
679 future_probs[future_arc] = arc.weight;
681 BitmapIndex::Set(future_bits, future_bit++);
689 if ((state_number != num_states) || (context_bit != num_states * 2 + 1) ||
690 (context_arc != num_states) || (future_arc != num_futures) ||
691 (future_bit != num_futures + num_states + 1) ||
692 (final_bit != num_final)) {
693 FSTERROR() << "Structure problems detected during construction";
694 SetProperties(kError, kError);
698 Init(data, false, data_region);
701 template <typename A>
702 inline void NGramFstImpl<A>::Init(const char *data, bool owned,
703 MappedFile *data_region) {
707 data_region_.reset(data_region);
711 num_states_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
712 offset += sizeof(num_states_);
713 num_futures_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
714 offset += sizeof(num_futures_);
715 num_final_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
716 offset += sizeof(num_final_);
718 size_t context_bits = num_states_ * 2 + 1;
719 size_t future_bits = num_futures_ + num_states_ + 1;
720 context_ = reinterpret_cast<const uint64 *>(data_ + offset);
721 offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
722 future_ = reinterpret_cast<const uint64 *>(data_ + offset);
723 offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
724 final_ = reinterpret_cast<const uint64 *>(data_ + offset);
725 offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
726 context_words_ = reinterpret_cast<const Label *>(data_ + offset);
727 offset += (num_states_ + 1) * sizeof(*context_words_);
728 future_words_ = reinterpret_cast<const Label *>(data_ + offset);
729 offset += num_futures_ * sizeof(*future_words_);
730 offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
731 backoff_ = reinterpret_cast<const Weight *>(data_ + offset);
732 offset += (num_states_ + 1) * sizeof(*backoff_);
733 final_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
734 offset += num_final_ * sizeof(*final_probs_);
735 future_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
737 context_index_.BuildIndex(context_, context_bits);
738 future_index_.BuildIndex(future_, future_bits);
739 final_index_.BuildIndex(final_, num_states_);
741 select_root_ = context_index_.Select0s(0);
742 if (context_index_.Rank1(0) != 0 || select_root_.first != 1 ||
743 context_index_.Get(2) == false) {
744 FSTERROR() << "Malformed file";
745 SetProperties(kError, kError);
748 root_children_ = context_words_ + context_index_.Rank1(2);
752 template <typename A>
753 inline typename A::StateId NGramFstImpl<A>::Transition(
754 const std::vector<Label> &context, Label future) const {
755 const Label *children = root_children_;
756 size_t num_children = select_root_.second - 2;
758 std::lower_bound(children, children + num_children, future);
759 if (loc == children + num_children || *loc != future) {
760 return context_index_.Rank1(0);
762 size_t node = 2 + loc - children;
763 size_t node_rank = context_index_.Rank1(node);
764 std::pair<size_t, size_t> zeros =
765 (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
766 size_t first_child = zeros.first + 1;
767 if (context_index_.Get(first_child) == false) {
768 return context_index_.Rank1(node);
770 size_t last_child = zeros.second - 1;
771 for (int word = context.size() - 1; word >= 0; --word) {
772 children = context_words_ + context_index_.Rank1(first_child);
773 loc = std::lower_bound(children, children + last_child - first_child + 1,
775 if (loc == children + last_child - first_child + 1 ||
776 *loc != context[word]) {
779 node = first_child + loc - children;
780 node_rank = context_index_.Rank1(node);
782 (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
783 first_child = zeros.first + 1;
784 if (context_index_.Get(first_child) == false) break;
785 last_child = zeros.second - 1;
787 return context_index_.Rank1(node);
790 } // namespace internal
792 /*****************************************************************************/
794 class NGramFstMatcher : public MatcherBase<A> {
797 typedef typename A::Label Label;
798 typedef typename A::StateId StateId;
799 typedef typename A::Weight Weight;
801 NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type)
804 match_type_(match_type),
805 current_loop_(false),
806 loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
807 if (match_type_ == MATCH_OUTPUT) {
808 std::swap(loop_.ilabel, loop_.olabel);
812 NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
813 : fst_(matcher.fst_),
814 inst_(matcher.inst_),
815 match_type_(matcher.match_type_),
816 current_loop_(false),
817 loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
818 if (match_type_ == MATCH_OUTPUT) {
819 std::swap(loop_.ilabel, loop_.olabel);
823 NGramFstMatcher<A> *Copy(bool safe = false) const override {
824 return new NGramFstMatcher<A>(*this, safe);
827 MatchType Type(bool test) const override { return match_type_; }
829 const Fst<A> &GetFst() const override { return fst_; }
831 uint64 Properties(uint64 props) const override { return props; }
833 void SetState(StateId s) final {
834 fst_.GetImpl()->SetInstFuture(s, &inst_);
835 current_loop_ = false;
838 bool Find(Label label) final {
839 const Label nolabel = kNoLabel;
841 if (label == 0 || label == nolabel) {
843 current_loop_ = true;
844 loop_.nextstate = inst_.state_;
846 // The unigram state has no epsilon arc.
847 if (inst_.state_ != 0) {
848 arc_.ilabel = arc_.olabel = 0;
849 fst_.GetImpl()->SetInstNode(&inst_);
850 arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
851 fst_.GetImpl()->context_index_.Select1(
852 fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
853 arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
857 current_loop_ = false;
858 const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
859 const Label *end = start + inst_.num_futures_;
860 const Label *search = std::lower_bound(start, end, label);
861 if (search != end && *search == label) {
862 size_t state = search - start;
863 arc_.ilabel = arc_.olabel = label;
864 arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
865 fst_.GetImpl()->SetInstContext(&inst_);
866 arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
873 bool Done() const final { return !current_loop_ && done_; }
875 const Arc &Value() const final { return (current_loop_) ? loop_ : arc_; }
879 current_loop_ = false;
885 ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
888 const NGramFst<A> &fst_;
889 NGramFstInst<A> inst_;
890 MatchType match_type_; // Supplied by caller
893 bool current_loop_; // Current arc is the implicit loop
897 /*****************************************************************************/
898 // Specialization for NGramFst; see generic version in fst.h
899 // for sample usage (but use the NGramFst type!). This version
902 class StateIterator<NGramFst<A>> : public StateIteratorBase<A> {
904 typedef typename A::StateId StateId;
906 explicit StateIterator(const NGramFst<A> &fst)
907 : s_(0), num_states_(fst.NumStates()) {}
909 bool Done() const final { return s_ >= num_states_; }
911 StateId Value() const final { return s_; }
913 void Next() final { ++s_; }
915 void Reset() final { s_ = 0; }
922 /*****************************************************************************/
924 class ArcIterator<NGramFst<A>> : public ArcIteratorBase<A> {
927 typedef typename A::Label Label;
928 typedef typename A::StateId StateId;
929 typedef typename A::Weight Weight;
931 ArcIterator(const NGramFst<A> &fst, StateId state)
932 : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
934 impl_->SetInstFuture(state, &inst_);
935 impl_->SetInstNode(&inst_);
938 bool Done() const final {
940 ((inst_.node_ == 0) ? inst_.num_futures_ : inst_.num_futures_ + 1);
943 const Arc &Value() const final {
944 bool eps = (inst_.node_ != 0 && i_ == 0);
945 StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
946 if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
947 arc_.ilabel = arc_.olabel =
948 eps ? 0 : impl_->future_words_[inst_.offset_ + state];
949 lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
951 if (flags_ & lazy_ & kArcNextStateValue) {
954 impl_->context_index_.Rank1(impl_->context_index_.Select1(
955 impl_->context_index_.Rank0(inst_.node_) - 1));
957 if (lazy_ & kArcNextStateValue) {
958 impl_->SetInstContext(&inst_); // first time only.
960 arc_.nextstate = impl_->Transition(
961 inst_.context_, impl_->future_words_[inst_.offset_ + state]);
963 lazy_ &= ~kArcNextStateValue;
965 if (flags_ & lazy_ & kArcWeightValue) {
966 arc_.weight = eps ? impl_->backoff_[inst_.state_]
967 : impl_->future_probs_[inst_.offset_ + state];
968 lazy_ &= ~kArcWeightValue;
978 size_t Position() const final { return i_; }
985 void Seek(size_t a) final {
992 uint32 Flags() const final { return flags_; }
994 void SetFlags(uint32 flags, uint32 mask) final {
996 flags_ |= (flags & kArcValueFlags);
1001 mutable uint32 lazy_;
1002 const internal::NGramFstImpl<A> *impl_; // Borrowed reference.
1003 mutable NGramFstInst<A> inst_;
1010 #endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_