532f458e53429112ef55ad2503f7c7e02eebea4b
[platform/upstream/openfst.git] / src / include / fst / extensions / ngram / ngram-fst.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // NgramFst implements a n-gram language model based upon the LOUDS data
5 // structure.  Please refer to "Unary Data Structures for Language Models"
6 // http://research.google.com/pubs/archive/37218.pdf
7
8 #ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
9 #define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
10
11 #include <stddef.h>
12 #include <string.h>
13 #include <algorithm>
14 #include <iostream>
15 #include <string>
16 #include <utility>
17 #include <vector>
18
19 #include <fst/compat.h>
20 #include <fst/log.h>
21 #include <fstream>
22 #include <fst/extensions/ngram/bitmap-index.h>
23 #include <fst/fstlib.h>
24 #include <fst/mapped-file.h>
25
26 namespace fst {
27 template <class A>
28 class NGramFst;
29 template <class A>
30 class NGramFstMatcher;
31
32 // Instance data containing mutable state for bookkeeping repeated access to
33 // the same state.
34 template <class A>
35 struct NGramFstInst {
36   typedef typename A::Label Label;
37   typedef typename A::StateId StateId;
38   typedef typename A::Weight Weight;
39   StateId state_;
40   size_t num_futures_;
41   size_t offset_;
42   size_t node_;
43   StateId node_state_;
44   std::vector<Label> context_;
45   StateId context_state_;
46   NGramFstInst()
47       : state_(kNoStateId),
48         node_state_(kNoStateId),
49         context_state_(kNoStateId) {}
50 };
51
52 namespace internal {
53
54 // Implementation class for LOUDS based NgramFst interface.
55 template <class A>
56 class NGramFstImpl : public FstImpl<A> {
57   using FstImpl<A>::SetInputSymbols;
58   using FstImpl<A>::SetOutputSymbols;
59   using FstImpl<A>::SetType;
60   using FstImpl<A>::WriteHeader;
61
62   friend class ArcIterator<NGramFst<A>>;
63   friend class NGramFstMatcher<A>;
64
65  public:
66   using FstImpl<A>::InputSymbols;
67   using FstImpl<A>::SetProperties;
68   using FstImpl<A>::Properties;
69
70   typedef A Arc;
71   typedef typename A::Label Label;
72   typedef typename A::StateId StateId;
73   typedef typename A::Weight Weight;
74
75   NGramFstImpl() {
76     SetType("ngram");
77     SetInputSymbols(nullptr);
78     SetOutputSymbols(nullptr);
79     SetProperties(kStaticProperties);
80   }
81
82   NGramFstImpl(const Fst<A> &fst, std::vector<StateId> *order_out);
83
84   explicit NGramFstImpl(const Fst<A> &fst) : NGramFstImpl(fst, nullptr) {}
85
86   NGramFstImpl(const NGramFstImpl &other) {
87     FSTERROR() << "Copying NGramFst Impls is not supported, use safe = false.";
88     SetProperties(kError, kError);
89   }
90
91   ~NGramFstImpl() override {
92     if (owned_) {
93       delete[] data_;
94     }
95   }
96
97   static NGramFstImpl<A> *Read(std::istream &strm,  // NOLINT
98                                const FstReadOptions &opts) {
99     NGramFstImpl<A> *impl = new NGramFstImpl();
100     FstHeader hdr;
101     if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0;
102     uint64 num_states, num_futures, num_final;
103     const size_t offset =
104         sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
105     // Peek at num_states and num_futures to see how much more needs to be read.
106     strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
107     strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
108     strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
109     size_t size = Storage(num_states, num_futures, num_final);
110     MappedFile *data_region = MappedFile::Allocate(size);
111     char *data = reinterpret_cast<char *>(data_region->mutable_data());
112     // Copy num_states, num_futures and num_final back into data.
113     memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
114     memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
115            sizeof(num_futures));
116     memcpy(data + sizeof(num_states) + sizeof(num_futures),
117            reinterpret_cast<char *>(&num_final), sizeof(num_final));
118     strm.read(data + offset, size - offset);
119     if (strm.fail()) {
120       delete impl;
121       return nullptr;
122     }
123     impl->Init(data, false, data_region);
124     return impl;
125   }
126
127   bool Write(std::ostream &strm,  // NOLINT
128              const FstWriteOptions &opts) const {
129     FstHeader hdr;
130     hdr.SetStart(Start());
131     hdr.SetNumStates(num_states_);
132     WriteHeader(strm, opts, kFileVersion, &hdr);
133     strm.write(data_, StorageSize());
134     return !strm.fail();
135   }
136
137   StateId Start() const { return start_; }
138
139   Weight Final(StateId state) const {
140     if (final_index_.Get(state)) {
141       return final_probs_[final_index_.Rank1(state)];
142     } else {
143       return Weight::Zero();
144     }
145   }
146
147   size_t NumArcs(StateId state, NGramFstInst<A> *inst = nullptr) const {
148     if (inst == nullptr) {
149       const std::pair<size_t, size_t> zeros =
150           (state == 0) ? select_root_ : future_index_.Select0s(state);
151       return zeros.second - zeros.first - 1;
152     }
153     SetInstFuture(state, inst);
154     return inst->num_futures_ + ((state == 0) ? 0 : 1);
155   }
156
157   size_t NumInputEpsilons(StateId state) const {
158     // State 0 has no parent, thus no backoff.
159     if (state == 0) return 0;
160     return 1;
161   }
162
163   size_t NumOutputEpsilons(StateId state) const {
164     return NumInputEpsilons(state);
165   }
166
167   StateId NumStates() const { return num_states_; }
168
169   void InitStateIterator(StateIteratorData<A> *data) const {
170     data->base = 0;
171     data->nstates = num_states_;
172   }
173
174   static size_t Storage(uint64 num_states, uint64 num_futures,
175                         uint64 num_final) {
176     uint64 b64;
177     Weight weight;
178     Label label;
179     size_t offset =
180         sizeof(num_states) + sizeof(num_futures) + sizeof(num_final);
181     offset +=
182         sizeof(b64) * (BitmapIndex::StorageSize(num_states * 2 + 1) +
183                        BitmapIndex::StorageSize(num_futures + num_states + 1) +
184                        BitmapIndex::StorageSize(num_states));
185     offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
186     // Pad for alignemnt, see
187     // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
188     offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
189     offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
190               (num_futures + 1) * sizeof(weight);
191     return offset;
192   }
193
194   void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
195     if (inst->state_ != state) {
196       inst->state_ = state;
197       const std::pair<size_t, size_t> zeros = future_index_.Select0s(state);
198       inst->num_futures_ = zeros.second - zeros.first - 1;
199       inst->offset_ = future_index_.Rank1(zeros.first + 1);
200     }
201   }
202
203   void SetInstNode(NGramFstInst<A> *inst) const {
204     if (inst->node_state_ != inst->state_) {
205       inst->node_state_ = inst->state_;
206       inst->node_ = context_index_.Select1(inst->state_);
207     }
208   }
209
210   void SetInstContext(NGramFstInst<A> *inst) const {
211     SetInstNode(inst);
212     if (inst->context_state_ != inst->state_) {
213       inst->context_state_ = inst->state_;
214       inst->context_.clear();
215       size_t node = inst->node_;
216       while (node != 0) {
217         inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
218         node = context_index_.Select1(context_index_.Rank0(node) - 1);
219       }
220     }
221   }
222
223   // Access to the underlying representation
224   const char *GetData(size_t *data_size) const {
225     *data_size = StorageSize();
226     return data_;
227   }
228
229   void Init(const char *data, bool owned, MappedFile *file = nullptr);
230
231   const std::vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
232     SetInstFuture(s, inst);
233     SetInstContext(inst);
234     return inst->context_;
235   }
236
237   size_t StorageSize() const {
238     return Storage(num_states_, num_futures_, num_final_);
239   }
240
241   void GetStates(const std::vector<Label> &context,
242                  std::vector<StateId> *states) const;
243
244  private:
245   StateId Transition(const std::vector<Label> &context, Label future) const;
246
247   // Properties always true for this Fst class.
248   static const uint64 kStaticProperties =
249       kAcceptor | kIDeterministic | kODeterministic | kEpsilons | kIEpsilons |
250       kOEpsilons | kILabelSorted | kOLabelSorted | kWeighted | kCyclic |
251       kInitialAcyclic | kNotTopSorted | kAccessible | kCoAccessible |
252       kNotString | kExpanded;
253   // Current file format version.
254   static const int kFileVersion = 4;
255   // Minimum file format version supported.
256   static const int kMinFileVersion = 4;
257
258   std::unique_ptr<MappedFile> data_region_;
259   const char *data_ = nullptr;
260   bool owned_ = false;  // True if we own data_
261   StateId start_ = fst::kNoStateId;
262   uint64 num_states_ = 0;
263   uint64 num_futures_ = 0;
264   uint64 num_final_ = 0;
265   std::pair<size_t, size_t> select_root_;
266   const Label *root_children_ = nullptr;
267   // borrowed references
268   const uint64 *context_ = nullptr;
269   const uint64 *future_ = nullptr;
270   const uint64 *final_ = nullptr;
271   const Label *context_words_ = nullptr;
272   const Label *future_words_ = nullptr;
273   const Weight *backoff_ = nullptr;
274   const Weight *final_probs_ = nullptr;
275   const Weight *future_probs_ = nullptr;
276   BitmapIndex context_index_;
277   BitmapIndex future_index_;
278   BitmapIndex final_index_;
279 };
280
281 template <typename A>
282 inline void NGramFstImpl<A>::GetStates(
283     const std::vector<Label> &context,
284     std::vector<typename A::StateId> *states) const {
285   states->clear();
286   states->push_back(0);
287   typename std::vector<Label>::const_reverse_iterator cit = context.rbegin();
288   const Label *children = root_children_;
289   size_t num_children = select_root_.second - 2;
290   const Label *loc = std::lower_bound(children, children + num_children, *cit);
291   if (loc == children + num_children || *loc != *cit) return;
292   size_t node = 2 + loc - children;
293   states->push_back(context_index_.Rank1(node));
294   if (context.size() == 1) return;
295   size_t node_rank = context_index_.Rank1(node);
296   std::pair<size_t, size_t> zeros =
297       node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
298   size_t first_child = zeros.first + 1;
299   ++cit;
300   if (context_index_.Get(first_child) != false) {
301     size_t last_child = zeros.second - 1;
302     while (cit != context.rend()) {
303       children = context_words_ + context_index_.Rank1(first_child);
304       loc = std::lower_bound(children, children + last_child - first_child + 1,
305                              *cit);
306       if (loc == children + last_child - first_child + 1 || *loc != *cit) {
307         break;
308       }
309       ++cit;
310       node = first_child + loc - children;
311       states->push_back(context_index_.Rank1(node));
312       node_rank = context_index_.Rank1(node);
313       zeros =
314           node_rank == 0 ? select_root_ : context_index_.Select0s(node_rank);
315       first_child = zeros.first + 1;
316       if (context_index_.Get(first_child) == false) break;
317       last_child = zeros.second - 1;
318     }
319   }
320 }
321
322 }  // namespace internal
323
324 /*****************************************************************************/
325 template <class A>
326 class NGramFst : public ImplToExpandedFst<internal::NGramFstImpl<A>> {
327   friend class ArcIterator<NGramFst<A>>;
328   friend class NGramFstMatcher<A>;
329
330  public:
331   typedef A Arc;
332   typedef typename A::StateId StateId;
333   typedef typename A::Label Label;
334   typedef typename A::Weight Weight;
335   typedef internal::NGramFstImpl<A> Impl;
336
337   explicit NGramFst(const Fst<A> &dst)
338       : ImplToExpandedFst<Impl>(std::make_shared<Impl>(dst, nullptr)) {}
339
340   NGramFst(const Fst<A> &fst, std::vector<StateId> *order_out)
341       : ImplToExpandedFst<Impl>(std::make_shared<Impl>(fst, order_out)) {}
342
343   // Because the NGramFstImpl is a const stateless data structure, there
344   // is never a need to do anything beside copy the reference.
345   NGramFst(const NGramFst<A> &fst, bool safe = false)
346       : ImplToExpandedFst<Impl>(fst, false) {}
347
348   NGramFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {}
349
350   // Non-standard constructor to initialize NGramFst directly from data.
351   NGramFst(const char *data, bool owned)
352       : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {
353     GetMutableImpl()->Init(data, owned, nullptr);
354   }
355
356   // Get method that gets the data associated with Init().
357   const char *GetData(size_t *data_size) const {
358     return GetImpl()->GetData(data_size);
359   }
360
361   const std::vector<Label> GetContext(StateId s) const {
362     return GetImpl()->GetContext(s, &inst_);
363   }
364
365   // Consumes as much as possible of context from right to left, returns the
366   // the states corresponding to the increasingly conditioned input sequence.
367   void GetStates(const std::vector<Label> &context,
368                  std::vector<StateId> *state) const {
369     return GetImpl()->GetStates(context, state);
370   }
371
372   size_t NumArcs(StateId s) const override {
373     return GetImpl()->NumArcs(s, &inst_);
374   }
375
376   NGramFst<A> *Copy(bool safe = false) const override {
377     return new NGramFst(*this, safe);
378   }
379
380   static NGramFst<A> *Read(std::istream &strm, const FstReadOptions &opts) {
381     Impl *impl = Impl::Read(strm, opts);
382     return impl ? new NGramFst<A>(std::shared_ptr<Impl>(impl)) : nullptr;
383   }
384
385   static NGramFst<A> *Read(const string &filename) {
386     if (!filename.empty()) {
387       std::ifstream strm(filename,
388                               std::ios_base::in | std::ios_base::binary);
389       if (!strm.good()) {
390         LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename;
391         return nullptr;
392       }
393       return Read(strm, FstReadOptions(filename));
394     } else {
395       return Read(std::cin, FstReadOptions("standard input"));
396     }
397   }
398
399   bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
400     return GetImpl()->Write(strm, opts);
401   }
402
403   bool Write(const string &filename) const override {
404     return Fst<A>::WriteFile(filename);
405   }
406
407   inline void InitStateIterator(StateIteratorData<A> *data) const override {
408     GetImpl()->InitStateIterator(data);
409   }
410
411   inline void InitArcIterator(StateId s,
412                               ArcIteratorData<A> *data) const override;
413
414   MatcherBase<A> *InitMatcher(MatchType match_type) const override {
415     return new NGramFstMatcher<A>(*this, match_type);
416   }
417
418   size_t StorageSize() const { return GetImpl()->StorageSize(); }
419
420   static bool HasRequiredProps(const Fst<A> &fst) {
421     int64 props =
422         kAcceptor | kIDeterministic | kILabelSorted | kIEpsilons | kAccessible;
423     return fst.Properties(props, true) == props;
424   }
425
426   static bool HasRequiredStructure(const Fst<A> &fst) {
427     if (!HasRequiredProps(fst)) {
428       return false;
429     }
430     typename A::StateId unigram = fst.Start();
431     while (true) {  // Follows epsilon arc chain to find unigram state.
432       if (unigram == fst::kNoStateId) return false;  // No unigram state.
433       typename fst::ArcIterator<Fst<A>> aiter(fst, unigram);
434       if (aiter.Done() || aiter.Value().ilabel != 0) break;
435       unigram = aiter.Value().nextstate;
436       aiter.Next();
437     }
438     // Other requirement: all states other than unigram an epsilon arc.
439     for (fst::StateIterator<Fst<A>> siter(fst); !siter.Done();
440          siter.Next()) {
441       const typename A::StateId &state = siter.Value();
442       fst::ArcIterator<Fst<A>> aiter(fst, state);
443       if (state != unigram) {
444         if (aiter.Done()) return false;
445         if (aiter.Value().ilabel != 0) return false;
446         aiter.Next();
447         if (!aiter.Done() && aiter.Value().ilabel == 0) return false;
448       }
449     }
450     return true;
451   }
452
453  private:
454   using ImplToExpandedFst<Impl, ExpandedFst<A>>::GetImpl;
455   using ImplToExpandedFst<Impl, ExpandedFst<A>>::GetMutableImpl;
456
457   explicit NGramFst(std::shared_ptr<Impl> impl)
458       : ImplToExpandedFst<Impl>(impl) {}
459
460   mutable NGramFstInst<A> inst_;
461 };
462
463 template <class A>
464 inline void NGramFst<A>::InitArcIterator(StateId s,
465                                          ArcIteratorData<A> *data) const {
466   GetImpl()->SetInstFuture(s, &inst_);
467   GetImpl()->SetInstNode(&inst_);
468   data->base = new ArcIterator<NGramFst<A>>(*this, s);
469 }
470
471 namespace internal {
472
473 template <typename A>
474 NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst,
475                               std::vector<StateId> *order_out) {
476   typedef A Arc;
477   typedef typename Arc::Label Label;
478   typedef typename Arc::Weight Weight;
479   typedef typename Arc::StateId StateId;
480   SetType("ngram");
481   SetInputSymbols(fst.InputSymbols());
482   SetOutputSymbols(fst.OutputSymbols());
483   SetProperties(kStaticProperties);
484
485   // Check basic requirements for an OpenGrm language model Fst.
486   if (!NGramFst<A>::HasRequiredProps(fst)) {
487     FSTERROR() << "NGramFst only accepts OpenGrm language models as input";
488     SetProperties(kError, kError);
489     return;
490   }
491
492   int64 num_states = CountStates(fst);
493   Label *context = new Label[num_states];
494
495   // Find the unigram state by starting from the start state, following
496   // epsilons.
497   StateId unigram = fst.Start();
498   while (1) {
499     if (unigram == kNoStateId) {
500       FSTERROR() << "Could not identify unigram state";
501       SetProperties(kError, kError);
502       return;
503     }
504     ArcIterator<Fst<A>> aiter(fst, unigram);
505     if (aiter.Done()) {
506       LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
507       break;
508     }
509     if (aiter.Value().ilabel != 0) break;
510     unigram = aiter.Value().nextstate;
511   }
512
513   // Each state's context is determined by the subtree it is under from the
514   // unigram state.
515   std::queue<std::pair<StateId, Label>> label_queue;
516   std::vector<bool> visited(num_states);
517   // Force an epsilon link to the start state.
518   label_queue.push(std::make_pair(fst.Start(), 0));
519   for (ArcIterator<Fst<A>> aiter(fst, unigram); !aiter.Done(); aiter.Next()) {
520     label_queue.push(
521         std::make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
522   }
523   // investigate states in breadth first fashion to assign context words.
524   while (!label_queue.empty()) {
525     std::pair<StateId, Label> &now = label_queue.front();
526     if (!visited[now.first]) {
527       context[now.first] = now.second;
528       visited[now.first] = true;
529       for (ArcIterator<Fst<A>> aiter(fst, now.first); !aiter.Done();
530            aiter.Next()) {
531         const Arc &arc = aiter.Value();
532         if (arc.ilabel != 0) {
533           label_queue.push(std::make_pair(arc.nextstate, now.second));
534         }
535       }
536     }
537     label_queue.pop();
538   }
539   visited.clear();
540
541   // The arc from the start state should be assigned an epsilon to put it
542   // in front of the all other labels (which makes Start state 1 after
543   // unigram which is state 0).
544   context[fst.Start()] = 0;
545
546   // Build the tree of contexts fst by reversing the epsilon arcs from fst.
547   VectorFst<Arc> context_fst;
548   uint64 num_final = 0;
549   for (int i = 0; i < num_states; ++i) {
550     if (fst.Final(i) != Weight::Zero()) {
551       ++num_final;
552     }
553     context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
554   }
555   context_fst.SetStart(unigram);
556   context_fst.SetInputSymbols(fst.InputSymbols());
557   context_fst.SetOutputSymbols(fst.OutputSymbols());
558   int64 num_context_arcs = 0;
559   int64 num_futures = 0;
560   for (StateIterator<Fst<A>> siter(fst); !siter.Done(); siter.Next()) {
561     const StateId &state = siter.Value();
562     num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
563     ArcIterator<Fst<A>> aiter(fst, state);
564     if (!aiter.Done()) {
565       const Arc &arc = aiter.Value();
566       // this arc goes from state to arc.nextstate, so create an arc from
567       // arc.nextstate to state to reverse it.
568       if (arc.ilabel == 0) {
569         context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
570                                               arc.weight, state));
571         num_context_arcs++;
572       }
573     }
574   }
575   if (num_context_arcs != context_fst.NumStates() - 1) {
576     FSTERROR() << "Number of contexts arcs != number of states - 1";
577     SetProperties(kError, kError);
578     return;
579   }
580   if (context_fst.NumStates() != num_states) {
581     FSTERROR() << "Number of contexts != number of states";
582     SetProperties(kError, kError);
583     return;
584   }
585   int64 context_props =
586       context_fst.Properties(kIDeterministic | kILabelSorted, true);
587   if (!(context_props & kIDeterministic)) {
588     FSTERROR() << "Input Fst is not structured properly";
589     SetProperties(kError, kError);
590     return;
591   }
592   if (!(context_props & kILabelSorted)) {
593     ArcSort(&context_fst, ILabelCompare<Arc>());
594   }
595
596   delete[] context;
597
598   uint64 b64;
599   Weight weight;
600   Label label = kNoLabel;
601   const size_t storage = Storage(num_states, num_futures, num_final);
602   MappedFile *data_region = MappedFile::Allocate(storage);
603   char *data = reinterpret_cast<char *>(data_region->mutable_data());
604   memset(data, 0, storage);
605   size_t offset = 0;
606   memcpy(data + offset, reinterpret_cast<char *>(&num_states),
607          sizeof(num_states));
608   offset += sizeof(num_states);
609   memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
610          sizeof(num_futures));
611   offset += sizeof(num_futures);
612   memcpy(data + offset, reinterpret_cast<char *>(&num_final),
613          sizeof(num_final));
614   offset += sizeof(num_final);
615   uint64 *context_bits = reinterpret_cast<uint64 *>(data + offset);
616   offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
617   uint64 *future_bits = reinterpret_cast<uint64 *>(data + offset);
618   offset +=
619       BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
620   uint64 *final_bits = reinterpret_cast<uint64 *>(data + offset);
621   offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
622   Label *context_words = reinterpret_cast<Label *>(data + offset);
623   offset += (num_states + 1) * sizeof(label);
624   Label *future_words = reinterpret_cast<Label *>(data + offset);
625   offset += num_futures * sizeof(label);
626   offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
627   Weight *backoff = reinterpret_cast<Weight *>(data + offset);
628   offset += (num_states + 1) * sizeof(weight);
629   Weight *final_probs = reinterpret_cast<Weight *>(data + offset);
630   offset += num_final * sizeof(weight);
631   Weight *future_probs = reinterpret_cast<Weight *>(data + offset);
632   int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
633         final_bit = 0;
634
635   // pseudo-root bits
636   BitmapIndex::Set(context_bits, context_bit++);
637   ++context_bit;
638   context_words[context_arc] = label;
639   backoff[context_arc] = Weight::Zero();
640   context_arc++;
641
642   ++future_bit;
643   if (order_out) {
644     order_out->clear();
645     order_out->resize(num_states);
646   }
647
648   std::queue<StateId> context_q;
649   context_q.push(context_fst.Start());
650   StateId state_number = 0;
651   while (!context_q.empty()) {
652     const StateId &state = context_q.front();
653     if (order_out) {
654       (*order_out)[state] = state_number;
655     }
656
657     const Weight final_weight = context_fst.Final(state);
658     if (final_weight != Weight::Zero()) {
659       BitmapIndex::Set(final_bits, state_number);
660       final_probs[final_bit] = final_weight;
661       ++final_bit;
662     }
663
664     for (ArcIterator<VectorFst<A>> aiter(context_fst, state); !aiter.Done();
665          aiter.Next()) {
666       const Arc &arc = aiter.Value();
667       context_words[context_arc] = arc.ilabel;
668       backoff[context_arc] = arc.weight;
669       ++context_arc;
670       BitmapIndex::Set(context_bits, context_bit++);
671       context_q.push(arc.nextstate);
672     }
673     ++context_bit;
674
675     for (ArcIterator<Fst<A>> aiter(fst, state); !aiter.Done(); aiter.Next()) {
676       const Arc &arc = aiter.Value();
677       if (arc.ilabel != 0) {
678         future_words[future_arc] = arc.ilabel;
679         future_probs[future_arc] = arc.weight;
680         ++future_arc;
681         BitmapIndex::Set(future_bits, future_bit++);
682       }
683     }
684     ++future_bit;
685     ++state_number;
686     context_q.pop();
687   }
688
689   if ((state_number != num_states) || (context_bit != num_states * 2 + 1) ||
690       (context_arc != num_states) || (future_arc != num_futures) ||
691       (future_bit != num_futures + num_states + 1) ||
692       (final_bit != num_final)) {
693     FSTERROR() << "Structure problems detected during construction";
694     SetProperties(kError, kError);
695     return;
696   }
697
698   Init(data, false, data_region);
699 }
700
701 template <typename A>
702 inline void NGramFstImpl<A>::Init(const char *data, bool owned,
703                                   MappedFile *data_region) {
704   if (owned_) {
705     delete[] data_;
706   }
707   data_region_.reset(data_region);
708   owned_ = owned;
709   data_ = data;
710   size_t offset = 0;
711   num_states_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
712   offset += sizeof(num_states_);
713   num_futures_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
714   offset += sizeof(num_futures_);
715   num_final_ = *(reinterpret_cast<const uint64 *>(data_ + offset));
716   offset += sizeof(num_final_);
717   uint64 bits;
718   size_t context_bits = num_states_ * 2 + 1;
719   size_t future_bits = num_futures_ + num_states_ + 1;
720   context_ = reinterpret_cast<const uint64 *>(data_ + offset);
721   offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
722   future_ = reinterpret_cast<const uint64 *>(data_ + offset);
723   offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
724   final_ = reinterpret_cast<const uint64 *>(data_ + offset);
725   offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
726   context_words_ = reinterpret_cast<const Label *>(data_ + offset);
727   offset += (num_states_ + 1) * sizeof(*context_words_);
728   future_words_ = reinterpret_cast<const Label *>(data_ + offset);
729   offset += num_futures_ * sizeof(*future_words_);
730   offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
731   backoff_ = reinterpret_cast<const Weight *>(data_ + offset);
732   offset += (num_states_ + 1) * sizeof(*backoff_);
733   final_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
734   offset += num_final_ * sizeof(*final_probs_);
735   future_probs_ = reinterpret_cast<const Weight *>(data_ + offset);
736
737   context_index_.BuildIndex(context_, context_bits);
738   future_index_.BuildIndex(future_, future_bits);
739   final_index_.BuildIndex(final_, num_states_);
740
741   select_root_ = context_index_.Select0s(0);
742   if (context_index_.Rank1(0) != 0 || select_root_.first != 1 ||
743       context_index_.Get(2) == false) {
744     FSTERROR() << "Malformed file";
745     SetProperties(kError, kError);
746     return;
747   }
748   root_children_ = context_words_ + context_index_.Rank1(2);
749   start_ = 1;
750 }
751
752 template <typename A>
753 inline typename A::StateId NGramFstImpl<A>::Transition(
754     const std::vector<Label> &context, Label future) const {
755   const Label *children = root_children_;
756   size_t num_children = select_root_.second - 2;
757   const Label *loc =
758       std::lower_bound(children, children + num_children, future);
759   if (loc == children + num_children || *loc != future) {
760     return context_index_.Rank1(0);
761   }
762   size_t node = 2 + loc - children;
763   size_t node_rank = context_index_.Rank1(node);
764   std::pair<size_t, size_t> zeros =
765       (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
766   size_t first_child = zeros.first + 1;
767   if (context_index_.Get(first_child) == false) {
768     return context_index_.Rank1(node);
769   }
770   size_t last_child = zeros.second - 1;
771   for (int word = context.size() - 1; word >= 0; --word) {
772     children = context_words_ + context_index_.Rank1(first_child);
773     loc = std::lower_bound(children, children + last_child - first_child + 1,
774                            context[word]);
775     if (loc == children + last_child - first_child + 1 ||
776         *loc != context[word]) {
777       break;
778     }
779     node = first_child + loc - children;
780     node_rank = context_index_.Rank1(node);
781     zeros =
782         (node_rank == 0) ? select_root_ : context_index_.Select0s(node_rank);
783     first_child = zeros.first + 1;
784     if (context_index_.Get(first_child) == false) break;
785     last_child = zeros.second - 1;
786   }
787   return context_index_.Rank1(node);
788 }
789
790 }  // namespace internal
791
792 /*****************************************************************************/
793 template <class A>
794 class NGramFstMatcher : public MatcherBase<A> {
795  public:
796   typedef A Arc;
797   typedef typename A::Label Label;
798   typedef typename A::StateId StateId;
799   typedef typename A::Weight Weight;
800
801   NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type)
802       : fst_(fst),
803         inst_(fst.inst_),
804         match_type_(match_type),
805         current_loop_(false),
806         loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
807     if (match_type_ == MATCH_OUTPUT) {
808       std::swap(loop_.ilabel, loop_.olabel);
809     }
810   }
811
812   NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
813       : fst_(matcher.fst_),
814         inst_(matcher.inst_),
815         match_type_(matcher.match_type_),
816         current_loop_(false),
817         loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
818     if (match_type_ == MATCH_OUTPUT) {
819       std::swap(loop_.ilabel, loop_.olabel);
820     }
821   }
822
823   NGramFstMatcher<A> *Copy(bool safe = false) const override {
824     return new NGramFstMatcher<A>(*this, safe);
825   }
826
827   MatchType Type(bool test) const override { return match_type_; }
828
829   const Fst<A> &GetFst() const override { return fst_; }
830
831   uint64 Properties(uint64 props) const override { return props; }
832
833   void SetState(StateId s) final {
834     fst_.GetImpl()->SetInstFuture(s, &inst_);
835     current_loop_ = false;
836   }
837
838   bool Find(Label label) final {
839     const Label nolabel = kNoLabel;
840     done_ = true;
841     if (label == 0 || label == nolabel) {
842       if (label == 0) {
843         current_loop_ = true;
844         loop_.nextstate = inst_.state_;
845       }
846       // The unigram state has no epsilon arc.
847       if (inst_.state_ != 0) {
848         arc_.ilabel = arc_.olabel = 0;
849         fst_.GetImpl()->SetInstNode(&inst_);
850         arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
851             fst_.GetImpl()->context_index_.Select1(
852                 fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
853         arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
854         done_ = false;
855       }
856     } else {
857       current_loop_ = false;
858       const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
859       const Label *end = start + inst_.num_futures_;
860       const Label *search = std::lower_bound(start, end, label);
861       if (search != end && *search == label) {
862         size_t state = search - start;
863         arc_.ilabel = arc_.olabel = label;
864         arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
865         fst_.GetImpl()->SetInstContext(&inst_);
866         arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
867         done_ = false;
868       }
869     }
870     return !Done();
871   }
872
873   bool Done() const final { return !current_loop_ && done_; }
874
875   const Arc &Value() const final { return (current_loop_) ? loop_ : arc_; }
876
877   void Next() final {
878     if (current_loop_) {
879       current_loop_ = false;
880     } else {
881       done_ = true;
882     }
883   }
884
885   ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
886
887  private:
888   const NGramFst<A> &fst_;
889   NGramFstInst<A> inst_;
890   MatchType match_type_;  // Supplied by caller
891   bool done_;
892   Arc arc_;
893   bool current_loop_;  // Current arc is the implicit loop
894   Arc loop_;
895 };
896
897 /*****************************************************************************/
898 // Specialization for NGramFst; see generic version in fst.h
899 // for sample usage (but use the ProdLmFst type!). This version
900 // should inline.
901 template <class A>
902 class StateIterator<NGramFst<A>> : public StateIteratorBase<A> {
903  public:
904   typedef typename A::StateId StateId;
905
906   explicit StateIterator(const NGramFst<A> &fst)
907       : s_(0), num_states_(fst.NumStates()) {}
908
909   bool Done() const final { return s_ >= num_states_; }
910
911   StateId Value() const final { return s_; }
912
913   void Next() final { ++s_; }
914
915   void Reset() final { s_ = 0; }
916
917  private:
918   StateId s_;
919   StateId num_states_;
920 };
921
922 /*****************************************************************************/
923 template <class A>
924 class ArcIterator<NGramFst<A>> : public ArcIteratorBase<A> {
925  public:
926   typedef A Arc;
927   typedef typename A::Label Label;
928   typedef typename A::StateId StateId;
929   typedef typename A::Weight Weight;
930
931   ArcIterator(const NGramFst<A> &fst, StateId state)
932       : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
933     inst_ = fst.inst_;
934     impl_->SetInstFuture(state, &inst_);
935     impl_->SetInstNode(&inst_);
936   }
937
938   bool Done() const final {
939     return i_ >=
940            ((inst_.node_ == 0) ? inst_.num_futures_ : inst_.num_futures_ + 1);
941   }
942
943   const Arc &Value() const final {
944     bool eps = (inst_.node_ != 0 && i_ == 0);
945     StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
946     if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
947       arc_.ilabel = arc_.olabel =
948           eps ? 0 : impl_->future_words_[inst_.offset_ + state];
949       lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
950     }
951     if (flags_ & lazy_ & kArcNextStateValue) {
952       if (eps) {
953         arc_.nextstate =
954             impl_->context_index_.Rank1(impl_->context_index_.Select1(
955                 impl_->context_index_.Rank0(inst_.node_) - 1));
956       } else {
957         if (lazy_ & kArcNextStateValue) {
958           impl_->SetInstContext(&inst_);  // first time only.
959         }
960         arc_.nextstate = impl_->Transition(
961             inst_.context_, impl_->future_words_[inst_.offset_ + state]);
962       }
963       lazy_ &= ~kArcNextStateValue;
964     }
965     if (flags_ & lazy_ & kArcWeightValue) {
966       arc_.weight = eps ? impl_->backoff_[inst_.state_]
967                         : impl_->future_probs_[inst_.offset_ + state];
968       lazy_ &= ~kArcWeightValue;
969     }
970     return arc_;
971   }
972
973   void Next() final {
974     ++i_;
975     lazy_ = ~0;
976   }
977
978   size_t Position() const final { return i_; }
979
980   void Reset() final {
981     i_ = 0;
982     lazy_ = ~0;
983   }
984
985   void Seek(size_t a) final {
986     if (i_ != a) {
987       i_ = a;
988       lazy_ = ~0;
989     }
990   }
991
992   uint32 Flags() const final { return flags_; }
993
994   void SetFlags(uint32 flags, uint32 mask) final {
995     flags_ &= ~mask;
996     flags_ |= (flags & kArcValueFlags);
997   }
998
999  private:
1000   mutable Arc arc_;
1001   mutable uint32 lazy_;
1002   const internal::NGramFstImpl<A> *impl_;  // Borrowed reference.
1003   mutable NGramFstInst<A> inst_;
1004
1005   size_t i_;
1006   uint32 flags_;
1007 };
1008
1009 }  // namespace fst
1010 #endif  // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_