a2c19e73b5008891046132474d53ff44469b6222
[platform/upstream/openfst.git] / src / include / fst / extensions / linear / linear-fst.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes for building, storing and representing log-linear models as FSTs.
5
6 #ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_H_
7 #define FST_EXTENSIONS_LINEAR_LINEAR_FST_H_
8
9 #include <algorithm>
10 #include <iostream>
11 #include <memory>
12 #include <vector>
13
14 #include <fst/compat.h>
15 #include <fst/log.h>
16 #include <fst/extensions/pdt/collection.h>
17 #include <fst/bi-table.h>
18 #include <fst/cache.h>
19 #include <fstream>
20 #include <fst/fst.h>
21 #include <fst/matcher.h>
22 #include <fst/symbol-table.h>
23
24 #include <fst/extensions/linear/linear-fst-data.h>
25
26 namespace fst {
27
28 // Forward declaration of the specialized matcher for both
29 // LinearTaggerFst and LinearClassifierFst.
30 template <class F>
31 class LinearFstMatcherTpl;
32
33 namespace internal {
34
35 // Implementation class for on-the-fly generated LinearTaggerFst with
36 // special optimization in matching.
37 template <class A>
38 class LinearTaggerFstImpl : public CacheImpl<A> {
39  public:
40   using FstImpl<A>::SetType;
41   using FstImpl<A>::SetProperties;
42   using FstImpl<A>::SetInputSymbols;
43   using FstImpl<A>::SetOutputSymbols;
44   using FstImpl<A>::WriteHeader;
45
46   using CacheBaseImpl<CacheState<A>>::PushArc;
47   using CacheBaseImpl<CacheState<A>>::HasArcs;
48   using CacheBaseImpl<CacheState<A>>::HasFinal;
49   using CacheBaseImpl<CacheState<A>>::HasStart;
50   using CacheBaseImpl<CacheState<A>>::SetArcs;
51   using CacheBaseImpl<CacheState<A>>::SetFinal;
52   using CacheBaseImpl<CacheState<A>>::SetStart;
53
54   typedef A Arc;
55   typedef typename A::Label Label;
56   typedef typename A::Weight Weight;
57   typedef typename A::StateId StateId;
58   typedef typename Collection<StateId, Label>::SetIterator NGramIterator;
59
60   // Constructs an empty FST by default.
61   LinearTaggerFstImpl()
62       : CacheImpl<A>(CacheOptions()),
63         data_(std::make_shared<LinearFstData<A>>()),
64         delay_(0) {
65     SetType("linear-tagger");
66   }
67
68   // Constructs the FST with given data storage and symbol
69   // tables.
70   //
71   // TODO(wuke): when there is no constraint on output we can delay
72   // less than `data->MaxFutureSize` positions.
73   LinearTaggerFstImpl(const LinearFstData<Arc> *data, const SymbolTable *isyms,
74                       const SymbolTable *osyms, CacheOptions opts)
75       : CacheImpl<A>(opts), data_(data), delay_(data->MaxFutureSize()) {
76     SetType("linear-tagger");
77     SetProperties(kILabelSorted, kFstProperties);
78     SetInputSymbols(isyms);
79     SetOutputSymbols(osyms);
80     ReserveStubSpace();
81   }
82
83   // Copy by sharing the underlying data storage.
84   LinearTaggerFstImpl(const LinearTaggerFstImpl &impl)
85       : CacheImpl<A>(impl), data_(impl.data_), delay_(impl.delay_) {
86     SetType("linear-tagger");
87     SetProperties(impl.Properties(), kCopyProperties);
88     SetInputSymbols(impl.InputSymbols());
89     SetOutputSymbols(impl.OutputSymbols());
90     ReserveStubSpace();
91   }
92
93   StateId Start() {
94     if (!HasStart()) {
95       StateId start = FindStartState();
96       SetStart(start);
97     }
98     return CacheImpl<A>::Start();
99   }
100
101   Weight Final(StateId s) {
102     if (!HasFinal(s)) {
103       state_stub_.clear();
104       FillState(s, &state_stub_);
105       if (CanBeFinal(state_stub_))
106         SetFinal(s, data_->FinalWeight(InternalBegin(state_stub_),
107                                        InternalEnd(state_stub_)));
108       else
109         SetFinal(s, Weight::Zero());
110     }
111     return CacheImpl<A>::Final(s);
112   }
113
114   size_t NumArcs(StateId s) {
115     if (!HasArcs(s)) Expand(s);
116     return CacheImpl<A>::NumArcs(s);
117   }
118
119   size_t NumInputEpsilons(StateId s) {
120     if (!HasArcs(s)) Expand(s);
121     return CacheImpl<A>::NumInputEpsilons(s);
122   }
123
124   size_t NumOutputEpsilons(StateId s) {
125     if (!HasArcs(s)) Expand(s);
126     return CacheImpl<A>::NumOutputEpsilons(s);
127   }
128
129   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
130     if (!HasArcs(s)) Expand(s);
131     CacheImpl<A>::InitArcIterator(s, data);
132   }
133
134   // Computes the outgoing transitions from a state, creating new
135   // destination states as needed.
136   void Expand(StateId s);
137
138   // Appends to `arcs` all out-going arcs from state `s` that matches `label` as
139   // the input label.
140   void MatchInput(StateId s, Label ilabel, std::vector<Arc> *arcs);
141
142   static LinearTaggerFstImpl *Read(std::istream &strm,
143                                    const FstReadOptions &opts);
144
145   bool Write(std::ostream &strm,  // NOLINT
146              const FstWriteOptions &opts) const {
147     FstHeader header;
148     header.SetStart(kNoStateId);
149     WriteHeader(strm, opts, kFileVersion, &header);
150     data_->Write(strm);
151     if (!strm) {
152       LOG(ERROR) << "LinearTaggerFst::Write: Write failed: " << opts.source;
153       return false;
154     }
155     return true;
156   }
157
158  private:
159   static const int kMinFileVersion;
160   static const int kFileVersion;
161
162   // A collection of functions to access parts of the state tuple. A
163   // state tuple is a vector of `Label`s with two parts:
164   //   [buffer] [internal].
165   //
166   // - [buffer] is a buffer of observed input labels with length
167   // `delay_`. `LinearFstData<A>::kStartOfSentence`
168   // (resp. `LinearFstData<A>::kEndOfSentence`) are used as
169   // paddings when the buffer has fewer than `delay_` elements, which
170   // can only appear as the prefix (resp. suffix) of the buffer.
171   //
172   // - [internal] is the internal state tuple for `LinearFstData`
173   typename std::vector<Label>::const_iterator BufferBegin(
174       const std::vector<Label> &state) const {
175     return state.begin();
176   }
177
178   typename std::vector<Label>::const_iterator BufferEnd(
179       const std::vector<Label> &state) const {
180     return state.begin() + delay_;
181   }
182
183   typename std::vector<Label>::const_iterator InternalBegin(
184       const std::vector<Label> &state) const {
185     return state.begin() + delay_;
186   }
187
188   typename std::vector<Label>::const_iterator InternalEnd(
189       const std::vector<Label> &state) const {
190     return state.end();
191   }
192
193   // The size of state tuples are fixed, reserve them in stubs
194   void ReserveStubSpace() {
195     state_stub_.reserve(delay_ + data_->NumGroups());
196     next_stub_.reserve(delay_ + data_->NumGroups());
197   }
198
199   // Computes the start state tuple and maps it to the start state id.
200   StateId FindStartState() {
201     // Empty buffer with start-of-sentence paddings
202     state_stub_.clear();
203     state_stub_.resize(delay_, LinearFstData<A>::kStartOfSentence);
204     // Append internal states
205     data_->EncodeStartState(&state_stub_);
206     return FindState(state_stub_);
207   }
208
209   // Tests whether the buffer in `(begin, end)` is empty.
210   bool IsEmptyBuffer(typename std::vector<Label>::const_iterator begin,
211                      typename std::vector<Label>::const_iterator end) const {
212     // The following is guanranteed by `ShiftBuffer()`:
213     // - buffer[i] == LinearFstData<A>::kEndOfSentence =>
214     //       buffer[i+x] == LinearFstData<A>::kEndOfSentence
215     // - buffer[i] == LinearFstData<A>::kStartOfSentence =>
216     //       buffer[i-x] == LinearFstData<A>::kStartOfSentence
217     return delay_ == 0 || *(end - 1) == LinearFstData<A>::kStartOfSentence ||
218            *begin == LinearFstData<A>::kEndOfSentence;
219   }
220
221   // Tests whether the given state tuple can be a final state. A state
222   // is final iff there is no observed input in the buffer.
223   bool CanBeFinal(const std::vector<Label> &state) {
224     return IsEmptyBuffer(BufferBegin(state), BufferEnd(state));
225   }
226
227   // Finds state corresponding to an n-gram. Creates new state if n-gram not
228   // found.
229   StateId FindState(const std::vector<Label> &ngram) {
230     StateId sparse = ngrams_.FindId(ngram, true);
231     StateId dense = condensed_.FindId(sparse, true);
232     return dense;
233   }
234
235   // Appends after `output` the state tuple corresponding to the state id. The
236   // state id must exist.
237   void FillState(StateId s, std::vector<Label> *output) {
238     s = condensed_.FindEntry(s);
239     for (NGramIterator it = ngrams_.FindSet(s); !it.Done(); it.Next()) {
240       Label label = it.Element();
241       output->push_back(label);
242     }
243   }
244
245   // Shifts the buffer in `state` by appending `ilabel` and popping
246   // the one in the front as the return value. `next_stub_` is a
247   // shifted buffer of size `delay_` where the first `delay_ - 1`
248   // elements are the last `delay_ - 1` elements in the buffer of
249   // `state`. The last (if any) element in `next_stub_` will be
250   // `ilabel` after the call returns.
251   Label ShiftBuffer(const std::vector<Label> &state, Label ilabel,
252                     std::vector<Label> *next_stub_);
253
254   // Builds an arc from state tuple `state` consuming `ilabel` and
255   // `olabel`. `next_stub_` is the buffer filled in `ShiftBuffer`.
256   Arc MakeArc(const std::vector<Label> &state, Label ilabel, Label olabel,
257               std::vector<Label> *next_stub_);
258
259   // Expands arcs from state `s`, equivalent to state tuple `state`,
260   // with input `ilabel`. `next_stub_` is the buffer filled in
261   // `ShiftBuffer`.
262   void ExpandArcs(StateId s, const std::vector<Label> &state, Label ilabel,
263                   std::vector<Label> *next_stub_);
264
265   // Appends arcs from state `s`, equivalent to state tuple `state`,
266   // with input `ilabel` to `arcs`. `next_stub_` is the buffer filled
267   // in `ShiftBuffer`.
268   void AppendArcs(StateId s, const std::vector<Label> &state, Label ilabel,
269                   std::vector<Label> *next_stub_, std::vector<Arc> *arcs);
270
271   std::shared_ptr<const LinearFstData<A>> data_;
272   size_t delay_;
273   // Mapping from internal state tuple to *non-consecutive* ids
274   Collection<StateId, Label> ngrams_;
275   // Mapping from non-consecutive id to actual state id
276   CompactHashBiTable<StateId, StateId, std::hash<StateId>> condensed_;
277   // Two frequently used vectors, reuse to avoid repeated heap
278   // allocation
279   std::vector<Label> state_stub_, next_stub_;
280
281   LinearTaggerFstImpl &operator=(const LinearTaggerFstImpl &) = delete;
282 };
283
284 template <class A>
285 const int LinearTaggerFstImpl<A>::kMinFileVersion = 1;
286
287 template <class A>
288 const int LinearTaggerFstImpl<A>::kFileVersion = 1;
289
290 template <class A>
291 inline typename A::Label LinearTaggerFstImpl<A>::ShiftBuffer(
292     const std::vector<Label> &state, Label ilabel,
293     std::vector<Label> *next_stub_) {
294   DCHECK(ilabel > 0 || ilabel == LinearFstData<A>::kEndOfSentence);
295   if (delay_ == 0) {
296     DCHECK_GT(ilabel, 0);
297     return ilabel;
298   } else {
299     (*next_stub_)[BufferEnd(*next_stub_) - next_stub_->begin() - 1] = ilabel;
300     return *BufferBegin(state);
301   }
302 }
303
304 template <class A>
305 inline A LinearTaggerFstImpl<A>::MakeArc(const std::vector<Label> &state,
306                                          Label ilabel, Label olabel,
307                                          std::vector<Label> *next_stub_) {
308   DCHECK(ilabel > 0 || ilabel == LinearFstData<A>::kEndOfSentence);
309   DCHECK(olabel > 0 || olabel == LinearFstData<A>::kStartOfSentence);
310   Weight weight(Weight::One());
311   data_->TakeTransition(BufferEnd(state), InternalBegin(state),
312                         InternalEnd(state), ilabel, olabel, next_stub_,
313                         &weight);
314   StateId nextstate = FindState(*next_stub_);
315   // Restore `next_stub_` to its size before the call
316   next_stub_->resize(delay_);
317   // In the actual arc, we use epsilons instead of boundaries.
318   return A(ilabel == LinearFstData<A>::kEndOfSentence ? 0 : ilabel,
319            olabel == LinearFstData<A>::kStartOfSentence ? 0 : olabel, weight,
320            nextstate);
321 }
322
323 template <class A>
324 inline void LinearTaggerFstImpl<A>::ExpandArcs(StateId s,
325                                                const std::vector<Label> &state,
326                                                Label ilabel,
327                                                std::vector<Label> *next_stub_) {
328   // Input label to constrain the output with, observed `delay_` steps
329   // back. `ilabel` is the input label to be put on the arc, which
330   // fires features.
331   Label obs_ilabel = ShiftBuffer(state, ilabel, next_stub_);
332   if (obs_ilabel == LinearFstData<A>::kStartOfSentence) {
333     // This happens when input is shorter than `delay_`.
334     PushArc(s, MakeArc(state, ilabel, LinearFstData<A>::kStartOfSentence,
335                        next_stub_));
336   } else {
337     std::pair<typename std::vector<typename A::Label>::const_iterator,
338               typename std::vector<typename A::Label>::const_iterator> range =
339         data_->PossibleOutputLabels(obs_ilabel);
340     for (typename std::vector<typename A::Label>::const_iterator it =
341              range.first;
342          it != range.second; ++it)
343       PushArc(s, MakeArc(state, ilabel, *it, next_stub_));
344   }
345 }
346
347 // TODO(wuke): this has much in duplicate with `ExpandArcs()`
348 template <class A>
349 inline void LinearTaggerFstImpl<A>::AppendArcs(StateId /*s*/,
350                                                const std::vector<Label> &state,
351                                                Label ilabel,
352                                                std::vector<Label> *next_stub_,
353                                                std::vector<Arc> *arcs) {
354   // Input label to constrain the output with, observed `delay_` steps
355   // back. `ilabel` is the input label to be put on the arc, which
356   // fires features.
357   Label obs_ilabel = ShiftBuffer(state, ilabel, next_stub_);
358   if (obs_ilabel == LinearFstData<A>::kStartOfSentence) {
359     // This happens when input is shorter than `delay_`.
360     arcs->push_back(
361         MakeArc(state, ilabel, LinearFstData<A>::kStartOfSentence, next_stub_));
362   } else {
363     std::pair<typename std::vector<typename A::Label>::const_iterator,
364               typename std::vector<typename A::Label>::const_iterator> range =
365         data_->PossibleOutputLabels(obs_ilabel);
366     for (typename std::vector<typename A::Label>::const_iterator it =
367              range.first;
368          it != range.second; ++it)
369       arcs->push_back(MakeArc(state, ilabel, *it, next_stub_));
370   }
371 }
372
373 template <class A>
374 void LinearTaggerFstImpl<A>::Expand(StateId s) {
375   VLOG(3) << "Expand " << s;
376   state_stub_.clear();
377   FillState(s, &state_stub_);
378
379   // Precompute the first `delay_ - 1` elements in the buffer of
380   // next states, which are identical for different input/output.
381   next_stub_.clear();
382   next_stub_.resize(delay_);
383   if (delay_ > 0)
384     std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_),
385               next_stub_.begin());
386
387   // Epsilon transition for flushing out the next observed input
388   if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_)))
389     ExpandArcs(s, state_stub_, LinearFstData<A>::kEndOfSentence, &next_stub_);
390
391   // Non-epsilon input when we haven't flushed
392   if (delay_ == 0 ||
393       *(BufferEnd(state_stub_) - 1) != LinearFstData<A>::kEndOfSentence)
394     for (Label ilabel = data_->MinInputLabel();
395          ilabel <= data_->MaxInputLabel(); ++ilabel)
396       ExpandArcs(s, state_stub_, ilabel, &next_stub_);
397
398   SetArcs(s);
399 }
400
401 template <class A>
402 void LinearTaggerFstImpl<A>::MatchInput(StateId s, Label ilabel,
403                                         std::vector<Arc> *arcs) {
404   state_stub_.clear();
405   FillState(s, &state_stub_);
406
407   // Precompute the first `delay_ - 1` elements in the buffer of
408   // next states, which are identical for different input/output.
409   next_stub_.clear();
410   next_stub_.resize(delay_);
411   if (delay_ > 0)
412     std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_),
413               next_stub_.begin());
414
415   if (ilabel == 0) {
416     // Epsilon transition for flushing out the next observed input
417     if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_)))
418       AppendArcs(s, state_stub_, LinearFstData<A>::kEndOfSentence, &next_stub_,
419                  arcs);
420   } else {
421     // Non-epsilon input when we haven't flushed
422     if (delay_ == 0 ||
423         *(BufferEnd(state_stub_) - 1) != LinearFstData<A>::kEndOfSentence)
424       AppendArcs(s, state_stub_, ilabel, &next_stub_, arcs);
425   }
426 }
427
428 template <class A>
429 inline LinearTaggerFstImpl<A> *LinearTaggerFstImpl<A>::Read(
430     std::istream &strm, const FstReadOptions &opts) {  // NOLINT
431   std::unique_ptr<LinearTaggerFstImpl<A>> impl(new LinearTaggerFstImpl<A>());
432   FstHeader header;
433   if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) {
434     return nullptr;
435   }
436   impl->data_ = std::shared_ptr<LinearFstData<A>>(LinearFstData<A>::Read(strm));
437   if (!impl->data_) {
438     return nullptr;
439   }
440   impl->delay_ = impl->data_->MaxFutureSize();
441   impl->ReserveStubSpace();
442   return impl.release();
443 }
444
445 }  // namespace internal
446
447 // This class attaches interface to implementation and handles
448 // reference counting, delegating most methods to ImplToFst.
449 template <class A>
450 class LinearTaggerFst : public ImplToFst<internal::LinearTaggerFstImpl<A>> {
451  public:
452   friend class ArcIterator<LinearTaggerFst<A>>;
453   friend class StateIterator<LinearTaggerFst<A>>;
454   friend class LinearFstMatcherTpl<LinearTaggerFst<A>>;
455
456   typedef A Arc;
457   typedef typename A::Label Label;
458   typedef typename A::Weight Weight;
459   typedef typename A::StateId StateId;
460   typedef DefaultCacheStore<A> Store;
461   typedef typename Store::State State;
462   using Impl = internal::LinearTaggerFstImpl<A>;
463
464   LinearTaggerFst() : ImplToFst<Impl>(std::make_shared<Impl>()) {}
465
466   explicit LinearTaggerFst(LinearFstData<A> *data,
467                            const SymbolTable *isyms = nullptr,
468                            const SymbolTable *osyms = nullptr,
469                            CacheOptions opts = CacheOptions())
470       : ImplToFst<Impl>(std::make_shared<Impl>(data, isyms, osyms, opts)) {}
471
472   explicit LinearTaggerFst(const Fst<A> &fst)
473       : ImplToFst<Impl>(std::make_shared<Impl>()) {
474     LOG(FATAL) << "LinearTaggerFst: no constructor from arbitrary FST.";
475   }
476
477   // See Fst<>::Copy() for doc.
478   LinearTaggerFst(const LinearTaggerFst<A> &fst, bool safe = false)
479       : ImplToFst<Impl>(fst, safe) {}
480
481   // Get a copy of this LinearTaggerFst. See Fst<>::Copy() for further doc.
482   LinearTaggerFst<A> *Copy(bool safe = false) const override {
483     return new LinearTaggerFst<A>(*this, safe);
484   }
485
486   inline void InitStateIterator(StateIteratorData<A> *data) const override;
487
488   void InitArcIterator(StateId s, ArcIteratorData<A> *data) const override {
489     GetMutableImpl()->InitArcIterator(s, data);
490   }
491
492   MatcherBase<A> *InitMatcher(MatchType match_type) const override {
493     return new LinearFstMatcherTpl<LinearTaggerFst<A>>(*this, match_type);
494   }
495
496   static LinearTaggerFst<A> *Read(const string &filename) {
497     if (!filename.empty()) {
498       std::ifstream strm(filename,
499                               std::ios_base::in | std::ios_base::binary);
500       if (!strm) {
501         LOG(ERROR) << "LinearTaggerFst::Read: Can't open file: " << filename;
502         return nullptr;
503       }
504       return Read(strm, FstReadOptions(filename));
505     } else {
506       return Read(std::cin, FstReadOptions("standard input"));
507     }
508   }
509
510   static LinearTaggerFst<A> *Read(std::istream &in,  // NOLINT
511                                   const FstReadOptions &opts) {
512     auto *impl = Impl::Read(in, opts);
513     return impl ? new LinearTaggerFst<A>(std::shared_ptr<Impl>(impl)) : nullptr;
514   }
515
516   bool Write(const string &filename) const override {
517     if (!filename.empty()) {
518       std::ofstream strm(filename,
519                                std::ios_base::out | std::ios_base::binary);
520       if (!strm) {
521         LOG(ERROR) << "LinearTaggerFst::Write: Can't open file: " << filename;
522         return false;
523       }
524       return Write(strm, FstWriteOptions(filename));
525     } else {
526       return Write(std::cout, FstWriteOptions("standard output"));
527     }
528   }
529
530   bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
531     return GetImpl()->Write(strm, opts);
532   }
533
534  private:
535   using ImplToFst<Impl>::GetImpl;
536   using ImplToFst<Impl>::GetMutableImpl;
537
538   explicit LinearTaggerFst(std::shared_ptr<Impl> impl)
539       : ImplToFst<Impl>(impl) {}
540
541   void operator=(const LinearTaggerFst<A> &fst) = delete;
542 };
543
544 // Specialization for LinearTaggerFst.
545 template <class Arc>
546 class StateIterator<LinearTaggerFst<Arc>>
547     : public CacheStateIterator<LinearTaggerFst<Arc>> {
548  public:
549   explicit StateIterator(const LinearTaggerFst<Arc> &fst)
550       : CacheStateIterator<LinearTaggerFst<Arc>>(fst, fst.GetMutableImpl()) {}
551 };
552
553 // Specialization for LinearTaggerFst.
554 template <class Arc>
555 class ArcIterator<LinearTaggerFst<Arc>>
556     : public CacheArcIterator<LinearTaggerFst<Arc>> {
557  public:
558   using StateId = typename Arc::StateId;
559
560   ArcIterator(const LinearTaggerFst<Arc> &fst, StateId s)
561       : CacheArcIterator<LinearTaggerFst<Arc>>(fst.GetMutableImpl(), s) {
562     if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
563   }
564 };
565
566 template <class Arc>
567 inline void LinearTaggerFst<Arc>::InitStateIterator(
568     StateIteratorData<Arc> *data) const {
569   data->base = new StateIterator<LinearTaggerFst<Arc>>(*this);
570 }
571
572 namespace internal {
573
574 // Implementation class for on-the-fly generated LinearClassifierFst with
575 // special optimization in matching.
576 template <class A>
577 class LinearClassifierFstImpl : public CacheImpl<A> {
578  public:
579   using FstImpl<A>::SetType;
580   using FstImpl<A>::SetProperties;
581   using FstImpl<A>::SetInputSymbols;
582   using FstImpl<A>::SetOutputSymbols;
583   using FstImpl<A>::WriteHeader;
584
585   using CacheBaseImpl<CacheState<A>>::PushArc;
586   using CacheBaseImpl<CacheState<A>>::HasArcs;
587   using CacheBaseImpl<CacheState<A>>::HasFinal;
588   using CacheBaseImpl<CacheState<A>>::HasStart;
589   using CacheBaseImpl<CacheState<A>>::SetArcs;
590   using CacheBaseImpl<CacheState<A>>::SetFinal;
591   using CacheBaseImpl<CacheState<A>>::SetStart;
592
593   typedef A Arc;
594   typedef typename A::Label Label;
595   typedef typename A::Weight Weight;
596   typedef typename A::StateId StateId;
597   typedef typename Collection<StateId, Label>::SetIterator NGramIterator;
598
599   // Constructs an empty FST by default.
600   LinearClassifierFstImpl()
601       : CacheImpl<A>(CacheOptions()),
602         data_(std::make_shared<LinearFstData<A>>()) {
603     SetType("linear-classifier");
604     num_classes_ = 0;
605     num_groups_ = 0;
606   }
607
608   // Constructs the FST with given data storage, number of classes and
609   // symbol tables.
610   LinearClassifierFstImpl(const LinearFstData<Arc> *data, size_t num_classes,
611                           const SymbolTable *isyms, const SymbolTable *osyms,
612                           CacheOptions opts)
613       : CacheImpl<A>(opts),
614         data_(data),
615         num_classes_(num_classes),
616         num_groups_(data_->NumGroups() / num_classes_) {
617     SetType("linear-classifier");
618     SetProperties(kILabelSorted, kFstProperties);
619     SetInputSymbols(isyms);
620     SetOutputSymbols(osyms);
621     ReserveStubSpace();
622   }
623
624   // Copy by sharing the underlying data storage.
625   LinearClassifierFstImpl(const LinearClassifierFstImpl &impl)
626       : CacheImpl<A>(impl),
627         data_(impl.data_),
628         num_classes_(impl.num_classes_),
629         num_groups_(impl.num_groups_) {
630     SetType("linear-classifier");
631     SetProperties(impl.Properties(), kCopyProperties);
632     SetInputSymbols(impl.InputSymbols());
633     SetOutputSymbols(impl.OutputSymbols());
634     ReserveStubSpace();
635   }
636
637   StateId Start() {
638     if (!HasStart()) {
639       StateId start = FindStartState();
640       SetStart(start);
641     }
642     return CacheImpl<A>::Start();
643   }
644
645   Weight Final(StateId s) {
646     if (!HasFinal(s)) {
647       state_stub_.clear();
648       FillState(s, &state_stub_);
649       SetFinal(s, FinalWeight(state_stub_));
650     }
651     return CacheImpl<A>::Final(s);
652   }
653
654   size_t NumArcs(StateId s) {
655     if (!HasArcs(s)) Expand(s);
656     return CacheImpl<A>::NumArcs(s);
657   }
658
659   size_t NumInputEpsilons(StateId s) {
660     if (!HasArcs(s)) Expand(s);
661     return CacheImpl<A>::NumInputEpsilons(s);
662   }
663
664   size_t NumOutputEpsilons(StateId s) {
665     if (!HasArcs(s)) Expand(s);
666     return CacheImpl<A>::NumOutputEpsilons(s);
667   }
668
669   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
670     if (!HasArcs(s)) Expand(s);
671     CacheImpl<A>::InitArcIterator(s, data);
672   }
673
674   // Computes the outgoing transitions from a state, creating new
675   // destination states as needed.
676   void Expand(StateId s);
677
678   // Appends to `arcs` all out-going arcs from state `s` that matches
679   // `label` as the input label.
680   void MatchInput(StateId s, Label ilabel, std::vector<Arc> *arcs);
681
682   static LinearClassifierFstImpl<A> *Read(std::istream &strm,
683                                           const FstReadOptions &opts);
684
685   bool Write(std::ostream &strm, const FstWriteOptions &opts) const {
686     FstHeader header;
687     header.SetStart(kNoStateId);
688     WriteHeader(strm, opts, kFileVersion, &header);
689     data_->Write(strm);
690     WriteType(strm, num_classes_);
691     if (!strm) {
692       LOG(ERROR) << "LinearClassifierFst::Write: Write failed: " << opts.source;
693       return false;
694     }
695     return true;
696   }
697
698  private:
699   static const int kMinFileVersion;
700   static const int kFileVersion;
701
702   // A collection of functions to access parts of the state tuple. A
703   // state tuple is a vector of `Label`s with two parts:
704   //   [prediction] [internal].
705   //
706   // - [prediction] is a single label of the predicted class. A state
707   //   must have a positive class label, unless it is the start state.
708   //
709   // - [internal] is the internal state tuple for `LinearFstData` of
710   //   the given class; or kNoTrieNodeId's if in start state.
711   Label &Prediction(std::vector<Label> &state) { return state[0]; }  // NOLINT
712   Label Prediction(const std::vector<Label> &state) const { return state[0]; }
713
714   Label &InternalAt(std::vector<Label> &state, int index) {  // NOLINT
715     return state[index + 1];
716   }
717   Label InternalAt(const std::vector<Label> &state, int index) const {
718     return state[index + 1];
719   }
720
721   // The size of state tuples are fixed, reserve them in stubs
722   void ReserveStubSpace() {
723     size_t size = 1 + num_groups_;
724     state_stub_.reserve(size);
725     next_stub_.reserve(size);
726   }
727
728   // Computes the start state tuple and maps it to the start state id.
729   StateId FindStartState() {
730     // A start state tuple has no prediction
731     state_stub_.clear();
732     state_stub_.push_back(kNoLabel);
733     // For a start state, we don't yet know where we are in the tries.
734     for (size_t i = 0; i < num_groups_; ++i)
735       state_stub_.push_back(kNoTrieNodeId);
736     return FindState(state_stub_);
737   }
738
739   // Tests if the state tuple represents the start state.
740   bool IsStartState(const std::vector<Label> &state) const {
741     return state[0] == kNoLabel;
742   }
743
744   // Computes the actual group id in the data storage.
745   int GroupId(Label pred, int group) const {
746     return group * num_classes_ + pred - 1;
747   }
748
749   // Finds out the final weight of the given state. A state is final
750   // iff it is not the start.
751   Weight FinalWeight(const std::vector<Label> &state) const {
752     if (IsStartState(state)) {
753       return Weight::Zero();
754     }
755     Label pred = Prediction(state);
756     DCHECK_GT(pred, 0);
757     DCHECK_LE(pred, num_classes_);
758     Weight final_weight = Weight::One();
759     for (size_t group = 0; group < num_groups_; ++group) {
760       int group_id = GroupId(pred, group);
761       int trie_state = InternalAt(state, group);
762       final_weight =
763           Times(final_weight, data_->GroupFinalWeight(group_id, trie_state));
764     }
765     return final_weight;
766   }
767
768   // Finds state corresponding to an n-gram. Creates new state if n-gram not
769   // found.
770   StateId FindState(const std::vector<Label> &ngram) {
771     StateId sparse = ngrams_.FindId(ngram, true);
772     StateId dense = condensed_.FindId(sparse, true);
773     return dense;
774   }
775
776   // Appends after `output` the state tuple corresponding to the state id. The
777   // state id must exist.
778   void FillState(StateId s, std::vector<Label> *output) {
779     s = condensed_.FindEntry(s);
780     for (NGramIterator it = ngrams_.FindSet(s); !it.Done(); it.Next()) {
781       Label label = it.Element();
782       output->push_back(label);
783     }
784   }
785
786   std::shared_ptr<const LinearFstData<A>> data_;
787   // Division of groups in `data_`; num_classes_ * num_groups_ ==
788   // data_->NumGroups().
789   size_t num_classes_, num_groups_;
790   // Mapping from internal state tuple to *non-consecutive* ids
791   Collection<StateId, Label> ngrams_;
792   // Mapping from non-consecutive id to actual state id
793   CompactHashBiTable<StateId, StateId, std::hash<StateId>> condensed_;
794   // Two frequently used vectors, reuse to avoid repeated heap
795   // allocation
796   std::vector<Label> state_stub_, next_stub_;
797
798   void operator=(const LinearClassifierFstImpl<A> &) = delete;
799 };
800
801 template <class A>
802 const int LinearClassifierFstImpl<A>::kMinFileVersion = 0;
803
804 template <class A>
805 const int LinearClassifierFstImpl<A>::kFileVersion = 0;
806
807 template <class A>
808 void LinearClassifierFstImpl<A>::Expand(StateId s) {
809   VLOG(3) << "Expand " << s;
810   state_stub_.clear();
811   FillState(s, &state_stub_);
812   next_stub_.clear();
813   next_stub_.resize(1 + num_groups_);
814
815   if (IsStartState(state_stub_)) {
816     // Make prediction
817     for (Label pred = 1; pred <= num_classes_; ++pred) {
818       Prediction(next_stub_) = pred;
819       for (int i = 0; i < num_groups_; ++i)
820         InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i));
821       PushArc(s, A(0, pred, Weight::One(), FindState(next_stub_)));
822     }
823   } else {
824     Label pred = Prediction(state_stub_);
825     DCHECK_GT(pred, 0);
826     DCHECK_LE(pred, num_classes_);
827     for (Label ilabel = data_->MinInputLabel();
828          ilabel <= data_->MaxInputLabel(); ++ilabel) {
829       Prediction(next_stub_) = pred;
830       Weight weight = Weight::One();
831       for (int i = 0; i < num_groups_; ++i)
832         InternalAt(next_stub_, i) =
833             data_->GroupTransition(GroupId(pred, i), InternalAt(state_stub_, i),
834                                    ilabel, pred, &weight);
835       PushArc(s, A(ilabel, 0, weight, FindState(next_stub_)));
836     }
837   }
838
839   SetArcs(s);
840 }
841
842 template <class A>
843 void LinearClassifierFstImpl<A>::MatchInput(StateId s, Label ilabel,
844                                             std::vector<Arc> *arcs) {
845   state_stub_.clear();
846   FillState(s, &state_stub_);
847   next_stub_.clear();
848   next_stub_.resize(1 + num_groups_);
849
850   if (IsStartState(state_stub_)) {
851     // Make prediction if `ilabel` is epsilon.
852     if (ilabel == 0) {
853       for (Label pred = 1; pred <= num_classes_; ++pred) {
854         Prediction(next_stub_) = pred;
855         for (int i = 0; i < num_groups_; ++i)
856           InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i));
857         arcs->push_back(A(0, pred, Weight::One(), FindState(next_stub_)));
858       }
859     }
860   } else if (ilabel != 0) {
861     Label pred = Prediction(state_stub_);
862     Weight weight = Weight::One();
863     Prediction(next_stub_) = pred;
864     for (int i = 0; i < num_groups_; ++i)
865       InternalAt(next_stub_, i) = data_->GroupTransition(
866           GroupId(pred, i), InternalAt(state_stub_, i), ilabel, pred, &weight);
867     arcs->push_back(A(ilabel, 0, weight, FindState(next_stub_)));
868   }
869 }
870
871 template <class A>
872 inline LinearClassifierFstImpl<A> *LinearClassifierFstImpl<A>::Read(
873     std::istream &strm, const FstReadOptions &opts) {
874   std::unique_ptr<LinearClassifierFstImpl<A>> impl(
875       new LinearClassifierFstImpl<A>());
876   FstHeader header;
877   if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) {
878     return nullptr;
879   }
880   impl->data_ = std::shared_ptr<LinearFstData<A>>(LinearFstData<A>::Read(strm));
881   if (!impl->data_) {
882     return nullptr;
883   }
884   ReadType(strm, &impl->num_classes_);
885   if (!strm) {
886     return nullptr;
887   }
888   impl->num_groups_ = impl->data_->NumGroups() / impl->num_classes_;
889   if (impl->num_groups_ * impl->num_classes_ != impl->data_->NumGroups()) {
890     FSTERROR() << "Total number of feature groups is not a multiple of the "
891                   "number of classes: num groups = "
892                << impl->data_->NumGroups()
893                << ", num classes = " << impl->num_classes_;
894     return nullptr;
895   }
896   impl->ReserveStubSpace();
897   return impl.release();
898 }
899
900 }  // namespace internal
901
902 // This class attaches interface to implementation and handles
903 // reference counting, delegating most methods to ImplToFst.
904 template <class A>
905 class LinearClassifierFst
906     : public ImplToFst<internal::LinearClassifierFstImpl<A>> {
907  public:
908   friend class ArcIterator<LinearClassifierFst<A>>;
909   friend class StateIterator<LinearClassifierFst<A>>;
910   friend class LinearFstMatcherTpl<LinearClassifierFst<A>>;
911
912   typedef A Arc;
913   typedef typename A::Label Label;
914   typedef typename A::Weight Weight;
915   typedef typename A::StateId StateId;
916   typedef DefaultCacheStore<A> Store;
917   typedef typename Store::State State;
918   using Impl = internal::LinearClassifierFstImpl<A>;
919
920   LinearClassifierFst() : ImplToFst<Impl>(std::make_shared<Impl>()) {}
921
922   explicit LinearClassifierFst(LinearFstData<A> *data, size_t num_classes,
923                                const SymbolTable *isyms = nullptr,
924                                const SymbolTable *osyms = nullptr,
925                                CacheOptions opts = CacheOptions())
926       : ImplToFst<Impl>(
927             std::make_shared<Impl>(data, num_classes, isyms, osyms, opts)) {}
928
929   explicit LinearClassifierFst(const Fst<A> &fst)
930       : ImplToFst<Impl>(std::make_shared<Impl>()) {
931     LOG(FATAL) << "LinearClassifierFst: no constructor from arbitrary FST.";
932   }
933
934   // See Fst<>::Copy() for doc.
935   LinearClassifierFst(const LinearClassifierFst<A> &fst, bool safe = false)
936       : ImplToFst<Impl>(fst, safe) {}
937
938   // Get a copy of this LinearClassifierFst. See Fst<>::Copy() for further doc.
939   LinearClassifierFst<A> *Copy(bool safe = false) const override {
940     return new LinearClassifierFst<A>(*this, safe);
941   }
942
943   inline void InitStateIterator(StateIteratorData<A> *data) const override;
944
945   void InitArcIterator(StateId s, ArcIteratorData<A> *data) const override {
946     GetMutableImpl()->InitArcIterator(s, data);
947   }
948
949   MatcherBase<A> *InitMatcher(MatchType match_type) const override {
950     return new LinearFstMatcherTpl<LinearClassifierFst<A>>(*this, match_type);
951   }
952
953   static LinearClassifierFst<A> *Read(const string &filename) {
954     if (!filename.empty()) {
955       std::ifstream strm(filename,
956                               std::ios_base::in | std::ios_base::binary);
957       if (!strm) {
958         LOG(ERROR) << "LinearClassifierFst::Read: Can't open file: "
959                    << filename;
960         return nullptr;
961       }
962       return Read(strm, FstReadOptions(filename));
963     } else {
964       return Read(std::cin, FstReadOptions("standard input"));
965     }
966   }
967
968   static LinearClassifierFst<A> *Read(std::istream &in,
969                                       const FstReadOptions &opts) {
970     auto *impl = Impl::Read(in, opts);
971     return impl ? new LinearClassifierFst<A>(std::shared_ptr<Impl>(impl))
972                 : nullptr;
973   }
974
975   bool Write(const string &filename) const override {
976     if (!filename.empty()) {
977       std::ofstream strm(filename,
978                                std::ios_base::out | std::ios_base::binary);
979       if (!strm) {
980         LOG(ERROR) << "ProdLmFst::Write: Can't open file: " << filename;
981         return false;
982       }
983       return Write(strm, FstWriteOptions(filename));
984     } else {
985       return Write(std::cout, FstWriteOptions("standard output"));
986     }
987   }
988
989   bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
990     return GetImpl()->Write(strm, opts);
991   }
992
993  private:
994   using ImplToFst<Impl>::GetImpl;
995   using ImplToFst<Impl>::GetMutableImpl;
996
997   explicit LinearClassifierFst(std::shared_ptr<Impl> impl)
998       : ImplToFst<Impl>(impl) {}
999
1000   void operator=(const LinearClassifierFst<A> &fst) = delete;
1001 };
1002
1003 // Specialization for LinearClassifierFst.
1004 template <class Arc>
1005 class StateIterator<LinearClassifierFst<Arc>>
1006     : public CacheStateIterator<LinearClassifierFst<Arc>> {
1007  public:
1008   explicit StateIterator(const LinearClassifierFst<Arc> &fst)
1009       : CacheStateIterator<LinearClassifierFst<Arc>>(fst,
1010                                                      fst.GetMutableImpl()) {}
1011 };
1012
1013 // Specialization for LinearClassifierFst.
1014 template <class Arc>
1015 class ArcIterator<LinearClassifierFst<Arc>>
1016     : public CacheArcIterator<LinearClassifierFst<Arc>> {
1017  public:
1018   using StateId = typename Arc::StateId;
1019
1020   ArcIterator(const LinearClassifierFst<Arc> &fst, StateId s)
1021       : CacheArcIterator<LinearClassifierFst<Arc>>(fst.GetMutableImpl(), s) {
1022     if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
1023   }
1024 };
1025
1026 template <class Arc>
1027 inline void LinearClassifierFst<Arc>::InitStateIterator(
1028     StateIteratorData<Arc> *data) const {
1029   data->base = new StateIterator<LinearClassifierFst<Arc>>(*this);
1030 }
1031
1032 // Specialized Matcher for LinearFsts. This matcher only supports
1033 // matching from the input side. This is intentional because comparing
1034 // the scores of different input sequences with the same output
1035 // sequence is meaningless in a discriminative model.
1036 template <class F>
1037 class LinearFstMatcherTpl : public MatcherBase<typename F::Arc> {
1038  public:
1039   typedef typename F::Arc Arc;
1040   typedef typename Arc::Label Label;
1041   typedef typename Arc::Weight Weight;
1042   typedef typename Arc::StateId StateId;
1043   typedef F FST;
1044
1045   LinearFstMatcherTpl(const FST &fst, MatchType match_type)
1046       : fst_(fst.Copy()),
1047         match_type_(match_type),
1048         s_(kNoStateId),
1049         current_loop_(false),
1050         loop_(kNoLabel, 0, Weight::One(), kNoStateId),
1051         cur_arc_(0),
1052         error_(false) {
1053     switch (match_type_) {
1054       case MATCH_INPUT:
1055       case MATCH_OUTPUT:
1056       case MATCH_NONE:
1057         break;
1058       default:
1059         FSTERROR() << "LinearFstMatcherTpl: Bad match type";
1060         match_type_ = MATCH_NONE;
1061         error_ = true;
1062     }
1063   }
1064
1065   LinearFstMatcherTpl(const LinearFstMatcherTpl<F> &matcher, bool safe = false)
1066       : fst_(matcher.fst_->Copy(safe)),
1067         match_type_(matcher.match_type_),
1068         s_(kNoStateId),
1069         current_loop_(false),
1070         loop_(matcher.loop_),
1071         cur_arc_(0),
1072         error_(matcher.error_) {}
1073
1074   LinearFstMatcherTpl<F> *Copy(bool safe = false) const override {
1075     return new LinearFstMatcherTpl<F>(*this, safe);
1076   }
1077
1078   MatchType Type(bool /*test*/) const override {
1079     // `MATCH_INPUT` is the only valid type
1080     return match_type_ == MATCH_INPUT ? match_type_ : MATCH_NONE;
1081   }
1082
1083   void SetState(StateId s) final {
1084     if (s_ == s) return;
1085     s_ = s;
1086     // `MATCH_INPUT` is the only valid type
1087     if (match_type_ != MATCH_INPUT) {
1088       FSTERROR() << "LinearFstMatcherTpl: Bad match type";
1089       error_ = true;
1090     }
1091     loop_.nextstate = s;
1092   }
1093
1094   bool Find(Label label) final {
1095     if (error_) {
1096       current_loop_ = false;
1097       return false;
1098     }
1099     current_loop_ = label == 0;
1100     if (label == kNoLabel) label = 0;
1101     arcs_.clear();
1102     cur_arc_ = 0;
1103     fst_->GetMutableImpl()->MatchInput(s_, label, &arcs_);
1104     return current_loop_ || !arcs_.empty();
1105   }
1106
1107   bool Done() const final {
1108     return !(current_loop_ || cur_arc_ < arcs_.size());
1109   }
1110
1111   const Arc &Value() const final {
1112     return current_loop_ ? loop_ : arcs_[cur_arc_];
1113   }
1114
1115   void Next() final {
1116     if (current_loop_)
1117       current_loop_ = false;
1118     else
1119       ++cur_arc_;
1120   }
1121
1122   ssize_t Priority(StateId s) final { return kRequirePriority; }
1123
1124   const FST &GetFst() const override { return *fst_; }
1125
1126   uint64 Properties(uint64 props) const override {
1127     if (error_) props |= kError;
1128     return props;
1129   }
1130
1131   uint32 Flags() const override { return kRequireMatch; }
1132
1133  private:
1134   std::unique_ptr<const FST> fst_;
1135   MatchType match_type_;  // Type of match to perform.
1136   StateId s_;             // Current state.
1137   bool current_loop_;     // Current arc is the implicit loop.
1138   Arc loop_;              // For non-consuming symbols.
1139   // All out-going arcs matching the label in last Find() call.
1140   std::vector<Arc> arcs_;
1141   size_t cur_arc_;  // Index to the arc that `Value()` should return.
1142   bool error_;      // Error encountered.
1143 };
1144
1145 }  // namespace fst
1146
1147 #endif  // FST_EXTENSIONS_LINEAR_LINEAR_FST_H_