1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Classes for building, storing and representing log-linear models as FSTs.
6 #ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_H_
7 #define FST_EXTENSIONS_LINEAR_LINEAR_FST_H_
14 #include <fst/compat.h>
16 #include <fst/extensions/pdt/collection.h>
17 #include <fst/bi-table.h>
18 #include <fst/cache.h>
21 #include <fst/matcher.h>
22 #include <fst/symbol-table.h>
24 #include <fst/extensions/linear/linear-fst-data.h>
28 // Forward declaration of the specialized matcher for both
29 // LinearTaggerFst and LinearClassifierFst.
// Both FST classes below befriend this matcher and return it from
// InitMatcher(); its definition appears near the end of this header.
// NOTE(review): every use spells it as a template (`LinearFstMatcherTpl<...>`),
// so a `template <class F>` line is expected before this declaration; it is
// not present in this copy of the file -- confirm against upstream.
31 class LinearFstMatcherTpl;
35 // Implementation class for on-the-fly generated LinearTaggerFst with
36 // special optimization in matching.
// States are expanded lazily and cached through CacheImpl<A>. Each state id
// stands for a state tuple (see the accessor comments below) that is interned
// in `ngrams_` and then mapped to a consecutive id by `condensed_`.
// NOTE(review): this copy of the class omits several lines (e.g. the
// `template <class A>` header, access specifiers, the `Arc` typedef used by
// the constructor and MatchInput(), the `delay_` member declaration, and
// closing braces). The code lines below are left exactly as found.
38 class LinearTaggerFstImpl : public CacheImpl<A> {
40 using FstImpl<A>::SetType;
41 using FstImpl<A>::SetProperties;
42 using FstImpl<A>::SetInputSymbols;
43 using FstImpl<A>::SetOutputSymbols;
44 using FstImpl<A>::WriteHeader;
46 using CacheBaseImpl<CacheState<A>>::PushArc;
47 using CacheBaseImpl<CacheState<A>>::HasArcs;
48 using CacheBaseImpl<CacheState<A>>::HasFinal;
49 using CacheBaseImpl<CacheState<A>>::HasStart;
50 using CacheBaseImpl<CacheState<A>>::SetArcs;
51 using CacheBaseImpl<CacheState<A>>::SetFinal;
52 using CacheBaseImpl<CacheState<A>>::SetStart;
55 typedef typename A::Label Label;
56 typedef typename A::Weight Weight;
57 typedef typename A::StateId StateId;
58 typedef typename Collection<StateId, Label>::SetIterator NGramIterator;
60 // Constructs an empty FST by default.
62 : CacheImpl<A>(CacheOptions()),
63 data_(std::make_shared<LinearFstData<A>>()),
65 SetType("linear-tagger");
68 // Constructs the FST with given data storage and symbol tables.
// `data` is adopted by the shared_ptr member `data_`, so the impl takes
// ownership of it; `delay_` is set to data->MaxFutureSize().
71 // TODO(wuke): when there is no constraint on output we can delay
72 // less than `data->MaxFutureSize` positions.
73 LinearTaggerFstImpl(const LinearFstData<Arc> *data, const SymbolTable *isyms,
74 const SymbolTable *osyms, CacheOptions opts)
75 : CacheImpl<A>(opts), data_(data), delay_(data->MaxFutureSize()) {
76 SetType("linear-tagger");
77 SetProperties(kILabelSorted, kFstProperties);
78 SetInputSymbols(isyms);
79 SetOutputSymbols(osyms);
83 // Copy by sharing the underlying data storage.
84 LinearTaggerFstImpl(const LinearTaggerFstImpl &impl)
85 : CacheImpl<A>(impl), data_(impl.data_), delay_(impl.delay_) {
86 SetType("linear-tagger");
87 SetProperties(impl.Properties(), kCopyProperties);
88 SetInputSymbols(impl.InputSymbols());
89 SetOutputSymbols(impl.OutputSymbols());
// NOTE(review): the next two lines look like the body of a lazy Start()
// accessor (compute the start tuple, then answer from the cache); the
// enclosing lines are missing from this copy.
95 StateId start = FindStartState();
98 return CacheImpl<A>::Start();
// Final weight: if the state tuple has an empty buffer (CanBeFinal), it is
// data_->FinalWeight() over the internal part of the tuple, otherwise
// Weight::Zero(). The value is stored in the cache via SetFinal().
101 Weight Final(StateId s) {
104 FillState(s, &state_stub_);
105 if (CanBeFinal(state_stub_))
106 SetFinal(s, data_->FinalWeight(InternalBegin(state_stub_),
107 InternalEnd(state_stub_)));
109 SetFinal(s, Weight::Zero());
111 return CacheImpl<A>::Final(s);
// The accessors below expand `s` on first use and then answer from the cache.
114 size_t NumArcs(StateId s) {
115 if (!HasArcs(s)) Expand(s);
116 return CacheImpl<A>::NumArcs(s);
119 size_t NumInputEpsilons(StateId s) {
120 if (!HasArcs(s)) Expand(s);
121 return CacheImpl<A>::NumInputEpsilons(s);
124 size_t NumOutputEpsilons(StateId s) {
125 if (!HasArcs(s)) Expand(s);
126 return CacheImpl<A>::NumOutputEpsilons(s);
129 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
130 if (!HasArcs(s)) Expand(s);
131 CacheImpl<A>::InitArcIterator(s, data);
134 // Computes the outgoing transitions from a state, creating new
135 // destination states as needed.
136 void Expand(StateId s);
138 // Appends to `arcs` all out-going arcs from state `s` that match `ilabel` as
// the input label.
140 void MatchInput(StateId s, Label ilabel, std::vector<Arc> *arcs);
142 static LinearTaggerFstImpl *Read(std::istream &strm,
143 const FstReadOptions &opts);
// Writes an FST header stamped with kFileVersion whose start state is
// recorded as kNoStateId; the remaining write calls are not visible in this
// copy. Logs an error naming `opts.source` when the write fails.
145 bool Write(std::ostream &strm, // NOLINT
146 const FstWriteOptions &opts) const {
148 header.SetStart(kNoStateId);
149 WriteHeader(strm, opts, kFileVersion, &header);
152 LOG(ERROR) << "LinearTaggerFst::Write: Write failed: " << opts.source;
159 static const int kMinFileVersion;
160 static const int kFileVersion;
162 // A collection of functions to access parts of the state tuple. A
163 // state tuple is a vector of `Label`s with two parts:
164 // [buffer] [internal].
166 // - [buffer] is a buffer of observed input labels with length
167 // `delay_`. `LinearFstData<A>::kStartOfSentence`
168 // (resp. `LinearFstData<A>::kEndOfSentence`) are used as
169 // paddings when the buffer has fewer than `delay_` elements, which
170 // can only appear as the prefix (resp. suffix) of the buffer.
172 // - [internal] is the internal state tuple for `LinearFstData`
173 typename std::vector<Label>::const_iterator BufferBegin(
174 const std::vector<Label> &state) const {
175 return state.begin();
178 typename std::vector<Label>::const_iterator BufferEnd(
179 const std::vector<Label> &state) const {
180 return state.begin() + delay_;
183 typename std::vector<Label>::const_iterator InternalBegin(
184 const std::vector<Label> &state) const {
185 return state.begin() + delay_;
188 typename std::vector<Label>::const_iterator InternalEnd(
189 const std::vector<Label> &state) const {
193 // The size of state tuples is fixed, so reserve it in the stubs.
194 void ReserveStubSpace() {
195 state_stub_.reserve(delay_ + data_->NumGroups());
196 next_stub_.reserve(delay_ + data_->NumGroups());
199 // Computes the start state tuple and maps it to the start state id.
200 StateId FindStartState() {
201 // Empty buffer with start-of-sentence paddings
203 state_stub_.resize(delay_, LinearFstData<A>::kStartOfSentence);
204 // Append internal states
205 data_->EncodeStartState(&state_stub_);
206 return FindState(state_stub_);
209 // Tests whether the buffer in `(begin, end)` is empty.
210 bool IsEmptyBuffer(typename std::vector<Label>::const_iterator begin,
211 typename std::vector<Label>::const_iterator end) const {
212 // The following is guaranteed by `ShiftBuffer()`:
213 // - buffer[i] == LinearFstData<A>::kEndOfSentence =>
214 // buffer[i+x] == LinearFstData<A>::kEndOfSentence
215 // - buffer[i] == LinearFstData<A>::kStartOfSentence =>
216 // buffer[i-x] == LinearFstData<A>::kStartOfSentence
217 return delay_ == 0 || *(end - 1) == LinearFstData<A>::kStartOfSentence ||
218 *begin == LinearFstData<A>::kEndOfSentence;
221 // Tests whether the given state tuple can be a final state. A state
222 // is final iff there is no observed input in the buffer.
223 bool CanBeFinal(const std::vector<Label> &state) {
224 return IsEmptyBuffer(BufferBegin(state), BufferEnd(state));
227 // Finds state corresponding to an n-gram. Creates new state if n-gram not found.
229 StateId FindState(const std::vector<Label> &ngram) {
230 StateId sparse = ngrams_.FindId(ngram, true);
231 StateId dense = condensed_.FindId(sparse, true);
235 // Appends after `output` the state tuple corresponding to the state id. The
236 // state id must exist.
237 void FillState(StateId s, std::vector<Label> *output) {
238 s = condensed_.FindEntry(s);
239 for (NGramIterator it = ngrams_.FindSet(s); !it.Done(); it.Next()) {
240 Label label = it.Element();
241 output->push_back(label);
245 // Shifts the buffer in `state` by appending `ilabel` and popping
246 // the one in the front as the return value. `next_stub_` is a
247 // shifted buffer of size `delay_` where the first `delay_ - 1`
248 // elements are the last `delay_ - 1` elements in the buffer of
249 // `state`. The last (if any) element in `next_stub_` will be
250 // `ilabel` after the call returns.
// NOTE(review): in this and the following helpers the parameter `next_stub_`
// shadows the member of the same name.
251 Label ShiftBuffer(const std::vector<Label> &state, Label ilabel,
252 std::vector<Label> *next_stub_);
254 // Builds an arc from state tuple `state` consuming `ilabel` and
255 // `olabel`. `next_stub_` is the buffer filled in `ShiftBuffer`.
256 Arc MakeArc(const std::vector<Label> &state, Label ilabel, Label olabel,
257 std::vector<Label> *next_stub_);
259 // Expands arcs from state `s`, equivalent to state tuple `state`,
260 // with input `ilabel`. `next_stub_` is the buffer filled in `ShiftBuffer`.
262 void ExpandArcs(StateId s, const std::vector<Label> &state, Label ilabel,
263 std::vector<Label> *next_stub_);
265 // Appends arcs from state `s`, equivalent to state tuple `state`,
266 // with input `ilabel` to `arcs`. `next_stub_` is the buffer filled in `ShiftBuffer`.
268 void AppendArcs(StateId s, const std::vector<Label> &state, Label ilabel,
269 std::vector<Label> *next_stub_, std::vector<Arc> *arcs);
271 std::shared_ptr<const LinearFstData<A>> data_;
273 // Mapping from internal state tuple to *non-consecutive* ids
274 Collection<StateId, Label> ngrams_;
275 // Mapping from non-consecutive id to actual state id
276 CompactHashBiTable<StateId, StateId, std::hash<StateId>> condensed_;
277 // Two frequently used vectors, reused to avoid repeated heap allocation.
279 std::vector<Label> state_stub_, next_stub_;
281 LinearTaggerFstImpl &operator=(const LinearTaggerFstImpl &) = delete;
// File-format versions of LinearTaggerFstImpl: Write() stamps headers with
// kFileVersion, and Read() passes kMinFileVersion to ReadHeader().
285 const int LinearTaggerFstImpl<A>::kMinFileVersion = 1;
288 const int LinearTaggerFstImpl<A>::kFileVersion = 1;
// Writes `ilabel` into the last slot of the `delay_`-long buffer held in
// `*next_stub_` and returns the front (oldest) label of the buffer of
// `state`. `ilabel` must be a real input label (> 0) or kEndOfSentence.
// NOTE(review): the standalone DCHECK_GT(ilabel, 0) suggests a separate
// zero-delay branch whose surrounding lines are missing from this copy.
291 inline typename A::Label LinearTaggerFstImpl<A>::ShiftBuffer(
292 const std::vector<Label> &state, Label ilabel,
293 std::vector<Label> *next_stub_) {
294 DCHECK(ilabel > 0 || ilabel == LinearFstData<A>::kEndOfSentence);
296 DCHECK_GT(ilabel, 0);
299 (*next_stub_)[BufferEnd(*next_stub_) - next_stub_->begin() - 1] = ilabel;
300 return *BufferBegin(state);
// Builds the arc leaving tuple `state` on (`ilabel`, `olabel`).
// data_->TakeTransition() extends `*next_stub_` with the destination's
// internal tuple; that is why the stub is resized back to `delay_`
// afterwards. The destination id is looked up or created with FindState().
// Boundary labels (kEndOfSentence input, kStartOfSentence output) become
// epsilon (0) on the returned arc.
305 inline A LinearTaggerFstImpl<A>::MakeArc(const std::vector<Label> &state,
306 Label ilabel, Label olabel,
307 std::vector<Label> *next_stub_) {
308 DCHECK(ilabel > 0 || ilabel == LinearFstData<A>::kEndOfSentence);
309 DCHECK(olabel > 0 || olabel == LinearFstData<A>::kStartOfSentence);
310 Weight weight(Weight::One());
311 data_->TakeTransition(BufferEnd(state), InternalBegin(state),
312 InternalEnd(state), ilabel, olabel, next_stub_,
314 StateId nextstate = FindState(*next_stub_);
315 // Restore `next_stub_` to its size before the call
316 next_stub_->resize(delay_);
317 // In the actual arc, we use epsilons instead of boundaries.
318 return A(ilabel == LinearFstData<A>::kEndOfSentence ? 0 : ilabel,
319 olabel == LinearFstData<A>::kStartOfSentence ? 0 : olabel, weight,
// Pushes into the cache of `s` the arcs that consume `ilabel`. If the label
// leaving the buffer is still start-of-sentence padding, only one arc is
// pushed, and its output is kStartOfSentence (epsilon on the arc).
// Otherwise one arc is pushed for every label in
// data_->PossibleOutputLabels(obs_ilabel).
324 inline void LinearTaggerFstImpl<A>::ExpandArcs(StateId s,
325 const std::vector<Label> &state,
327 std::vector<Label> *next_stub_) {
328 // Input label to constrain the output with, observed `delay_` steps
329 // back. `ilabel` is the input label to be put on the arc, which
// enters the buffer as `obs_ilabel` leaves it.
331 Label obs_ilabel = ShiftBuffer(state, ilabel, next_stub_);
332 if (obs_ilabel == LinearFstData<A>::kStartOfSentence) {
333 // This happens when input is shorter than `delay_`.
334 PushArc(s, MakeArc(state, ilabel, LinearFstData<A>::kStartOfSentence,
337 std::pair<typename std::vector<typename A::Label>::const_iterator,
338 typename std::vector<typename A::Label>::const_iterator> range =
339 data_->PossibleOutputLabels(obs_ilabel);
340 for (typename std::vector<typename A::Label>::const_iterator it =
342 it != range.second; ++it)
343 PushArc(s, MakeArc(state, ilabel, *it, next_stub_));
347 // TODO(wuke): this has much in duplicate with `ExpandArcs()`
// Same arc construction as ExpandArcs(), but the arcs are appended to
// `*arcs` instead of being pushed into the cache; the state id is unused.
349 inline void LinearTaggerFstImpl<A>::AppendArcs(StateId /*s*/,
350 const std::vector<Label> &state,
352 std::vector<Label> *next_stub_,
353 std::vector<Arc> *arcs) {
354 // Input label to constrain the output with, observed `delay_` steps
355 // back. `ilabel` is the input label to be put on the arc, which
// enters the buffer as `obs_ilabel` leaves it.
357 Label obs_ilabel = ShiftBuffer(state, ilabel, next_stub_);
358 if (obs_ilabel == LinearFstData<A>::kStartOfSentence) {
359 // This happens when input is shorter than `delay_`.
361 MakeArc(state, ilabel, LinearFstData<A>::kStartOfSentence, next_stub_));
363 std::pair<typename std::vector<typename A::Label>::const_iterator,
364 typename std::vector<typename A::Label>::const_iterator> range =
365 data_->PossibleOutputLabels(obs_ilabel);
366 for (typename std::vector<typename A::Label>::const_iterator it =
368 it != range.second; ++it)
369 arcs->push_back(MakeArc(state, ilabel, *it, next_stub_));
// Expands state `s` into the cache. While the buffer is non-empty, an arc
// with epsilon input (kEndOfSentence) flushes the next observed label. In
// addition, while end-of-sentence has not entered the buffer, ExpandArcs()
// runs once per input label in [MinInputLabel(), MaxInputLabel()]. Part of
// that second condition is missing from this copy.
374 void LinearTaggerFstImpl<A>::Expand(StateId s) {
375 VLOG(3) << "Expand " << s;
377 FillState(s, &state_stub_);
379 // Precompute the first `delay_ - 1` elements in the buffer of
380 // next states, which are identical for different input/output.
382 next_stub_.resize(delay_);
384 std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_),
387 // Epsilon transition for flushing out the next observed input
388 if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_)))
389 ExpandArcs(s, state_stub_, LinearFstData<A>::kEndOfSentence, &next_stub_);
391 // Non-epsilon input when we haven't flushed
393 *(BufferEnd(state_stub_) - 1) != LinearFstData<A>::kEndOfSentence)
394 for (Label ilabel = data_->MinInputLabel();
395 ilabel <= data_->MaxInputLabel(); ++ilabel)
396 ExpandArcs(s, state_stub_, ilabel, &next_stub_);
// Appends to `*arcs` the arcs of `s` whose input is `ilabel`. This computes
// only the matching arcs and does not expand the whole state. The flushing
// (epsilon-input) arc is added when the buffer is non-empty; presumably this
// happens only for `ilabel == 0`, but the guarding line is missing from
// this copy. A non-epsilon `ilabel` is added while end-of-sentence has not
// entered the buffer.
402 void LinearTaggerFstImpl<A>::MatchInput(StateId s, Label ilabel,
403 std::vector<Arc> *arcs) {
405 FillState(s, &state_stub_);
407 // Precompute the first `delay_ - 1` elements in the buffer of
408 // next states, which are identical for different input/output.
410 next_stub_.resize(delay_);
412 std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_),
416 // Epsilon transition for flushing out the next observed input
417 if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_)))
418 AppendArcs(s, state_stub_, LinearFstData<A>::kEndOfSentence, &next_stub_,
421 // Non-epsilon input when we haven't flushed
423 *(BufferEnd(state_stub_) - 1) != LinearFstData<A>::kEndOfSentence)
424 AppendArcs(s, state_stub_, ilabel, &next_stub_, arcs);
// Reads an impl in the format produced by Write():
// - the FST header, read via ReadHeader() with kMinFileVersion;
// - the model data.
// It then derives `delay_` from the data and reserves the state stubs.
// The impl is held in a unique_ptr until it is released to the caller.
// The error-handling lines are missing from this copy.
429 inline LinearTaggerFstImpl<A> *LinearTaggerFstImpl<A>::Read(
430 std::istream &strm, const FstReadOptions &opts) { // NOLINT
431 std::unique_ptr<LinearTaggerFstImpl<A>> impl(new LinearTaggerFstImpl<A>());
433 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) {
436 impl->data_ = std::shared_ptr<LinearFstData<A>>(LinearFstData<A>::Read(strm));
440 impl->delay_ = impl->data_->MaxFutureSize();
441 impl->ReserveStubSpace();
442 return impl.release();
445 } // namespace internal
447 // This class attaches interface to implementation and handles
448 // reference counting, delegating most methods to ImplToFst.
// NOTE(review): the `template <class A>` header, access specifiers and
// closing braces of this class are missing from this copy.
450 class LinearTaggerFst : public ImplToFst<internal::LinearTaggerFstImpl<A>> {
452 friend class ArcIterator<LinearTaggerFst<A>>;
453 friend class StateIterator<LinearTaggerFst<A>>;
454 friend class LinearFstMatcherTpl<LinearTaggerFst<A>>;
457 typedef typename A::Label Label;
458 typedef typename A::Weight Weight;
459 typedef typename A::StateId StateId;
460 typedef DefaultCacheStore<A> Store;
461 typedef typename Store::State State;
462 using Impl = internal::LinearTaggerFstImpl<A>;
464 LinearTaggerFst() : ImplToFst<Impl>(std::make_shared<Impl>()) {}
// Builds the FST over `data`. The impl stores `data` in a shared_ptr, so
// ownership of the pointer is transferred.
466 explicit LinearTaggerFst(LinearFstData<A> *data,
467 const SymbolTable *isyms = nullptr,
468 const SymbolTable *osyms = nullptr,
469 CacheOptions opts = CacheOptions())
470 : ImplToFst<Impl>(std::make_shared<Impl>(data, isyms, osyms, opts)) {}
// Conversion from an arbitrary FST is unsupported: this aborts via
// LOG(FATAL).
472 explicit LinearTaggerFst(const Fst<A> &fst)
473 : ImplToFst<Impl>(std::make_shared<Impl>()) {
474 LOG(FATAL) << "LinearTaggerFst: no constructor from arbitrary FST.";
477 // See Fst<>::Copy() for doc.
478 LinearTaggerFst(const LinearTaggerFst<A> &fst, bool safe = false)
479 : ImplToFst<Impl>(fst, safe) {}
481 // Get a copy of this LinearTaggerFst. See Fst<>::Copy() for further doc.
482 LinearTaggerFst<A> *Copy(bool safe = false) const override {
483 return new LinearTaggerFst<A>(*this, safe);
486 inline void InitStateIterator(StateIteratorData<A> *data) const override;
488 void InitArcIterator(StateId s, ArcIteratorData<A> *data) const override {
489 GetMutableImpl()->InitArcIterator(s, data);
// Returns the specialized input-side matcher.
492 MatcherBase<A> *InitMatcher(MatchType match_type) const override {
493 return new LinearFstMatcherTpl<LinearTaggerFst<A>>(*this, match_type);
// Reads from `filename`, or from standard input when it is empty.
496 static LinearTaggerFst<A> *Read(const string &filename) {
497 if (!filename.empty()) {
498 std::ifstream strm(filename,
499 std::ios_base::in | std::ios_base::binary);
501 LOG(ERROR) << "LinearTaggerFst::Read: Can't open file: " << filename;
504 return Read(strm, FstReadOptions(filename));
506 return Read(std::cin, FstReadOptions("standard input"));
// Returns nullptr when the impl cannot be read.
510 static LinearTaggerFst<A> *Read(std::istream &in, // NOLINT
511 const FstReadOptions &opts) {
512 auto *impl = Impl::Read(in, opts);
513 return impl ? new LinearTaggerFst<A>(std::shared_ptr<Impl>(impl)) : nullptr;
// Writes to `filename`, or to standard output when it is empty.
516 bool Write(const string &filename) const override {
517 if (!filename.empty()) {
518 std::ofstream strm(filename,
519 std::ios_base::out | std::ios_base::binary);
521 LOG(ERROR) << "LinearTaggerFst::Write: Can't open file: " << filename;
524 return Write(strm, FstWriteOptions(filename));
526 return Write(std::cout, FstWriteOptions("standard output"));
530 bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
531 return GetImpl()->Write(strm, opts);
535 using ImplToFst<Impl>::GetImpl;
536 using ImplToFst<Impl>::GetMutableImpl;
// Wraps an already-built impl; used by Read().
538 explicit LinearTaggerFst(std::shared_ptr<Impl> impl)
539 : ImplToFst<Impl>(impl) {}
541 void operator=(const LinearTaggerFst<A> &fst) = delete;
544 // Specialization for LinearTaggerFst.
// Enumerates states through the cache of the FST's (mutable) impl.
546 class StateIterator<LinearTaggerFst<Arc>>
547 : public CacheStateIterator<LinearTaggerFst<Arc>> {
549 explicit StateIterator(const LinearTaggerFst<Arc> &fst)
550 : CacheStateIterator<LinearTaggerFst<Arc>>(fst, fst.GetMutableImpl()) {}
553 // Specialization for LinearTaggerFst.
// Expands state `s` if its arcs are not cached yet, then iterates the cache.
555 class ArcIterator<LinearTaggerFst<Arc>>
556 : public CacheArcIterator<LinearTaggerFst<Arc>> {
558 using StateId = typename Arc::StateId;
560 ArcIterator(const LinearTaggerFst<Arc> &fst, StateId s)
561 : CacheArcIterator<LinearTaggerFst<Arc>>(fst.GetMutableImpl(), s) {
562 if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
// Stores a newly allocated StateIterator for this FST in `data->base`.
567 inline void LinearTaggerFst<Arc>::InitStateIterator(
568 StateIteratorData<Arc> *data) const {
569 data->base = new StateIterator<LinearTaggerFst<Arc>>(*this);
574 // Implementation class for on-the-fly generated LinearClassifierFst with
575 // special optimization in matching.
// Each state id stands for a tuple [prediction][internal] (see below) that
// is interned in `ngrams_` and mapped to a consecutive id by `condensed_`.
// NOTE(review): this copy omits several lines (e.g. `template <class A>`,
// access specifiers, the `Arc` typedef, closing braces).
577 class LinearClassifierFstImpl : public CacheImpl<A> {
579 using FstImpl<A>::SetType;
580 using FstImpl<A>::SetProperties;
581 using FstImpl<A>::SetInputSymbols;
582 using FstImpl<A>::SetOutputSymbols;
583 using FstImpl<A>::WriteHeader;
585 using CacheBaseImpl<CacheState<A>>::PushArc;
586 using CacheBaseImpl<CacheState<A>>::HasArcs;
587 using CacheBaseImpl<CacheState<A>>::HasFinal;
588 using CacheBaseImpl<CacheState<A>>::HasStart;
589 using CacheBaseImpl<CacheState<A>>::SetArcs;
590 using CacheBaseImpl<CacheState<A>>::SetFinal;
591 using CacheBaseImpl<CacheState<A>>::SetStart;
594 typedef typename A::Label Label;
595 typedef typename A::Weight Weight;
596 typedef typename A::StateId StateId;
597 typedef typename Collection<StateId, Label>::SetIterator NGramIterator;
599 // Constructs an empty FST by default.
600 LinearClassifierFstImpl()
601 : CacheImpl<A>(CacheOptions()),
602 data_(std::make_shared<LinearFstData<A>>()) {
603 SetType("linear-classifier");
608 // Constructs the FST with given data storage, number of classes and
// symbol tables.
// NOTE(review): `num_groups_` is computed as data_->NumGroups() / num_classes_
// with no visible check that `num_classes` is positive or divides
// data->NumGroups(); Read() performs the divisibility check.
610 LinearClassifierFstImpl(const LinearFstData<Arc> *data, size_t num_classes,
611 const SymbolTable *isyms, const SymbolTable *osyms,
613 : CacheImpl<A>(opts),
615 num_classes_(num_classes),
616 num_groups_(data_->NumGroups() / num_classes_) {
617 SetType("linear-classifier");
618 SetProperties(kILabelSorted, kFstProperties);
619 SetInputSymbols(isyms);
620 SetOutputSymbols(osyms);
624 // Copy by sharing the underlying data storage.
625 LinearClassifierFstImpl(const LinearClassifierFstImpl &impl)
626 : CacheImpl<A>(impl),
628 num_classes_(impl.num_classes_),
629 num_groups_(impl.num_groups_) {
630 SetType("linear-classifier");
631 SetProperties(impl.Properties(), kCopyProperties);
632 SetInputSymbols(impl.InputSymbols());
633 SetOutputSymbols(impl.OutputSymbols());
// NOTE(review): the next two lines look like the body of a lazy Start()
// accessor; the enclosing lines are missing from this copy.
639 StateId start = FindStartState();
642 return CacheImpl<A>::Start();
// Final weight of `s`: computed by FinalWeight() on its tuple and stored in
// the cache via SetFinal().
645 Weight Final(StateId s) {
648 FillState(s, &state_stub_);
649 SetFinal(s, FinalWeight(state_stub_));
651 return CacheImpl<A>::Final(s);
// The accessors below expand `s` on first use and then answer from the cache.
654 size_t NumArcs(StateId s) {
655 if (!HasArcs(s)) Expand(s);
656 return CacheImpl<A>::NumArcs(s);
659 size_t NumInputEpsilons(StateId s) {
660 if (!HasArcs(s)) Expand(s);
661 return CacheImpl<A>::NumInputEpsilons(s);
664 size_t NumOutputEpsilons(StateId s) {
665 if (!HasArcs(s)) Expand(s);
666 return CacheImpl<A>::NumOutputEpsilons(s);
669 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
670 if (!HasArcs(s)) Expand(s);
671 CacheImpl<A>::InitArcIterator(s, data);
674 // Computes the outgoing transitions from a state, creating new
675 // destination states as needed.
676 void Expand(StateId s);
678 // Appends to `arcs` all out-going arcs from state `s` that match
679 // `ilabel` as the input label.
680 void MatchInput(StateId s, Label ilabel, std::vector<Arc> *arcs);
682 static LinearClassifierFstImpl<A> *Read(std::istream &strm,
683 const FstReadOptions &opts);
// Writes an FST header stamped with kFileVersion (start state recorded as
// kNoStateId) and the number of classes; the remaining write calls are not
// visible in this copy. Logs an error naming `opts.source` on failure.
685 bool Write(std::ostream &strm, const FstWriteOptions &opts) const {
687 header.SetStart(kNoStateId);
688 WriteHeader(strm, opts, kFileVersion, &header);
690 WriteType(strm, num_classes_);
692 LOG(ERROR) << "LinearClassifierFst::Write: Write failed: " << opts.source;
699 static const int kMinFileVersion;
700 static const int kFileVersion;
702 // A collection of functions to access parts of the state tuple. A
703 // state tuple is a vector of `Label`s with two parts:
704 // [prediction] [internal].
706 // - [prediction] is a single label of the predicted class. A state
707 // must have a positive class label, unless it is the start state.
709 // - [internal] is the internal state tuple for `LinearFstData` of
710 // the given class; or kNoTrieNodeId's if in start state.
711 Label &Prediction(std::vector<Label> &state) { return state[0]; } // NOLINT
712 Label Prediction(const std::vector<Label> &state) const { return state[0]; }
714 Label &InternalAt(std::vector<Label> &state, int index) { // NOLINT
715 return state[index + 1];
717 Label InternalAt(const std::vector<Label> &state, int index) const {
718 return state[index + 1];
721 // The size of state tuples is fixed, so reserve it in the stubs.
722 void ReserveStubSpace() {
723 size_t size = 1 + num_groups_;
724 state_stub_.reserve(size);
725 next_stub_.reserve(size);
728 // Computes the start state tuple and maps it to the start state id.
729 StateId FindStartState() {
730 // A start state tuple has no prediction
732 state_stub_.push_back(kNoLabel);
733 // For a start state, we don't yet know where we are in the tries.
734 for (size_t i = 0; i < num_groups_; ++i)
735 state_stub_.push_back(kNoTrieNodeId);
736 return FindState(state_stub_);
739 // Tests if the state tuple represents the start state.
740 bool IsStartState(const std::vector<Label> &state) const {
741 return state[0] == kNoLabel;
744 // Computes the actual group id in the data storage.
// The id is group * num_classes_ + (pred - 1), so for a fixed group the ids
// of classes 1..num_classes_ are contiguous.
745 int GroupId(Label pred, int group) const {
746 return group * num_classes_ + pred - 1;
749 // Finds out the final weight of the given state. A state is final
750 // iff it is not the start.
// Otherwise the weight combines, via Times(), each group's final weight at
// the state's trie node, using the groups of the predicted class.
751 Weight FinalWeight(const std::vector<Label> &state) const {
752 if (IsStartState(state)) {
753 return Weight::Zero();
755 Label pred = Prediction(state);
757 DCHECK_LE(pred, num_classes_);
758 Weight final_weight = Weight::One();
759 for (size_t group = 0; group < num_groups_; ++group) {
760 int group_id = GroupId(pred, group);
761 int trie_state = InternalAt(state, group);
763 Times(final_weight, data_->GroupFinalWeight(group_id, trie_state));
768 // Finds state corresponding to an n-gram. Creates new state if n-gram not found.
770 StateId FindState(const std::vector<Label> &ngram) {
771 StateId sparse = ngrams_.FindId(ngram, true);
772 StateId dense = condensed_.FindId(sparse, true);
776 // Appends after `output` the state tuple corresponding to the state id. The
777 // state id must exist.
778 void FillState(StateId s, std::vector<Label> *output) {
779 s = condensed_.FindEntry(s);
780 for (NGramIterator it = ngrams_.FindSet(s); !it.Done(); it.Next()) {
781 Label label = it.Element();
782 output->push_back(label);
786 std::shared_ptr<const LinearFstData<A>> data_;
787 // Division of groups in `data_`; num_classes_ * num_groups_ ==
788 // data_->NumGroups().
789 size_t num_classes_, num_groups_;
790 // Mapping from internal state tuple to *non-consecutive* ids
791 Collection<StateId, Label> ngrams_;
792 // Mapping from non-consecutive id to actual state id
793 CompactHashBiTable<StateId, StateId, std::hash<StateId>> condensed_;
794 // Two frequently used vectors, reused to avoid repeated heap allocation.
796 std::vector<Label> state_stub_, next_stub_;
798 void operator=(const LinearClassifierFstImpl<A> &) = delete;
// File-format versions of LinearClassifierFstImpl: Write() stamps headers
// with kFileVersion, and Read() passes kMinFileVersion to ReadHeader().
802 const int LinearClassifierFstImpl<A>::kMinFileVersion = 0;
805 const int LinearClassifierFstImpl<A>::kFileVersion = 0;
// Expands state `s` into the cache.
// - From the start state: one arc per class `pred` in [1, num_classes_].
//   Each arc has epsilon input, outputs `pred`, has weight One, and enters
//   the start states of that class's groups.
// - From any other state: one arc per input label in
//   [MinInputLabel(), MaxInputLabel()] with epsilon output. Each arc
//   advances every group of the predicted class through GroupTransition(),
//   which receives `&weight`.
808 void LinearClassifierFstImpl<A>::Expand(StateId s) {
809 VLOG(3) << "Expand " << s;
811 FillState(s, &state_stub_);
813 next_stub_.resize(1 + num_groups_);
815 if (IsStartState(state_stub_)) {
817 for (Label pred = 1; pred <= num_classes_; ++pred) {
818 Prediction(next_stub_) = pred;
819 for (int i = 0; i < num_groups_; ++i)
820 InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i));
821 PushArc(s, A(0, pred, Weight::One(), FindState(next_stub_)));
824 Label pred = Prediction(state_stub_);
826 DCHECK_LE(pred, num_classes_);
827 for (Label ilabel = data_->MinInputLabel();
828 ilabel <= data_->MaxInputLabel(); ++ilabel) {
829 Prediction(next_stub_) = pred;
830 Weight weight = Weight::One();
831 for (int i = 0; i < num_groups_; ++i)
832 InternalAt(next_stub_, i) =
833 data_->GroupTransition(GroupId(pred, i), InternalAt(state_stub_, i),
834 ilabel, pred, &weight);
835 PushArc(s, A(ilabel, 0, weight, FindState(next_stub_)));
// Appends to `*arcs` the arcs of `s` whose input is `ilabel`, without
// expanding the whole state.
// - At the start state: the per-class prediction arcs. Per the inline
//   comment this happens for an epsilon `ilabel`; the guarding line is
//   missing from this copy.
// - At any other state: a non-epsilon `ilabel` yields a single arc that
//   advances every group of the predicted class.
// NOTE(review): `ilabel` is not checked against MinInputLabel() /
// MaxInputLabel() in the visible lines.
843 void LinearClassifierFstImpl<A>::MatchInput(StateId s, Label ilabel,
844 std::vector<Arc> *arcs) {
846 FillState(s, &state_stub_);
848 next_stub_.resize(1 + num_groups_);
850 if (IsStartState(state_stub_)) {
851 // Make prediction if `ilabel` is epsilon.
853 for (Label pred = 1; pred <= num_classes_; ++pred) {
854 Prediction(next_stub_) = pred;
855 for (int i = 0; i < num_groups_; ++i)
856 InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i));
857 arcs->push_back(A(0, pred, Weight::One(), FindState(next_stub_)));
860 } else if (ilabel != 0) {
861 Label pred = Prediction(state_stub_);
862 Weight weight = Weight::One();
863 Prediction(next_stub_) = pred;
864 for (int i = 0; i < num_groups_; ++i)
865 InternalAt(next_stub_, i) = data_->GroupTransition(
866 GroupId(pred, i), InternalAt(state_stub_, i), ilabel, pred, &weight);
867 arcs->push_back(A(ilabel, 0, weight, FindState(next_stub_)));
// Reads an impl in the format produced by Write():
// - the FST header, read via ReadHeader() with kMinFileVersion;
// - the model data;
// - the number of classes.
// It then splits the data's feature groups evenly across the classes and
// reserves the state stubs. On a zero class count it reports FSTERROR() and
// returns nullptr; the unique_ptr releases the partially built impl. The
// other error-handling lines are missing from this copy.
872 inline LinearClassifierFstImpl<A> *LinearClassifierFstImpl<A>::Read(
873 std::istream &strm, const FstReadOptions &opts) {
874 std::unique_ptr<LinearClassifierFstImpl<A>> impl(
875 new LinearClassifierFstImpl<A>());
877 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) {
880 impl->data_ = std::shared_ptr<LinearFstData<A>>(LinearFstData<A>::Read(strm));
884 ReadType(strm, &impl->num_classes_);
// The class count comes from the (untrusted) stream; reject zero before it
// is used as a divisor below.
if (impl->num_classes_ == 0) {
FSTERROR() << "LinearClassifierFst::Read: Number of classes must be "
"positive";
return nullptr;
}
888 impl->num_groups_ = impl->data_->NumGroups() / impl->num_classes_;
889 if (impl->num_groups_ * impl->num_classes_ != impl->data_->NumGroups()) {
890 FSTERROR() << "Total number of feature groups is not a multiple of the "
891 "number of classes: num groups = "
892 << impl->data_->NumGroups()
893 << ", num classes = " << impl->num_classes_;
896 impl->ReserveStubSpace();
897 return impl.release();
900 } // namespace internal
902 // This class attaches interface to implementation and handles
903 // reference counting, delegating most methods to ImplToFst.
// NOTE(review): the `template <class A>` header, access specifiers, some
// continuation lines and closing braces of this class are missing from this
// copy.
905 class LinearClassifierFst
906 : public ImplToFst<internal::LinearClassifierFstImpl<A>> {
908 friend class ArcIterator<LinearClassifierFst<A>>;
909 friend class StateIterator<LinearClassifierFst<A>>;
910 friend class LinearFstMatcherTpl<LinearClassifierFst<A>>;
913 typedef typename A::Label Label;
914 typedef typename A::Weight Weight;
915 typedef typename A::StateId StateId;
916 typedef DefaultCacheStore<A> Store;
917 typedef typename Store::State State;
918 using Impl = internal::LinearClassifierFstImpl<A>;
920 LinearClassifierFst() : ImplToFst<Impl>(std::make_shared<Impl>()) {}
// Builds a classifier over `data` whose feature groups are divided among
// `num_classes` classes.
922 explicit LinearClassifierFst(LinearFstData<A> *data, size_t num_classes,
923 const SymbolTable *isyms = nullptr,
924 const SymbolTable *osyms = nullptr,
925 CacheOptions opts = CacheOptions())
927 std::make_shared<Impl>(data, num_classes, isyms, osyms, opts)) {}
// Conversion from an arbitrary FST is unsupported: this aborts via
// LOG(FATAL).
929 explicit LinearClassifierFst(const Fst<A> &fst)
930 : ImplToFst<Impl>(std::make_shared<Impl>()) {
931 LOG(FATAL) << "LinearClassifierFst: no constructor from arbitrary FST.";
934 // See Fst<>::Copy() for doc.
935 LinearClassifierFst(const LinearClassifierFst<A> &fst, bool safe = false)
936 : ImplToFst<Impl>(fst, safe) {}
938 // Get a copy of this LinearClassifierFst. See Fst<>::Copy() for further doc.
939 LinearClassifierFst<A> *Copy(bool safe = false) const override {
940 return new LinearClassifierFst<A>(*this, safe);
943 inline void InitStateIterator(StateIteratorData<A> *data) const override;
945 void InitArcIterator(StateId s, ArcIteratorData<A> *data) const override {
946 GetMutableImpl()->InitArcIterator(s, data);
// Returns the specialized input-side matcher.
949 MatcherBase<A> *InitMatcher(MatchType match_type) const override {
950 return new LinearFstMatcherTpl<LinearClassifierFst<A>>(*this, match_type);
// Reads from `filename`, or from standard input when it is empty.
953 static LinearClassifierFst<A> *Read(const string &filename) {
954 if (!filename.empty()) {
955 std::ifstream strm(filename,
956 std::ios_base::in | std::ios_base::binary);
958 LOG(ERROR) << "LinearClassifierFst::Read: Can't open file: "
962 return Read(strm, FstReadOptions(filename));
964 return Read(std::cin, FstReadOptions("standard input"));
968 static LinearClassifierFst<A> *Read(std::istream &in,
969 const FstReadOptions &opts) {
970 auto *impl = Impl::Read(in, opts);
971 return impl ? new LinearClassifierFst<A>(std::shared_ptr<Impl>(impl))
// Writes to `filename`, or to standard output when it is empty.
975 bool Write(const string &filename) const override {
976 if (!filename.empty()) {
977 std::ofstream strm(filename,
978 std::ios_base::out | std::ios_base::binary);
// The error names this class, consistent with the Read() message above.
980 LOG(ERROR) << "LinearClassifierFst::Write: Can't open file: " << filename;
983 return Write(strm, FstWriteOptions(filename));
985 return Write(std::cout, FstWriteOptions("standard output"));
989 bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
990 return GetImpl()->Write(strm, opts);
994 using ImplToFst<Impl>::GetImpl;
995 using ImplToFst<Impl>::GetMutableImpl;
// Wraps an already-built impl; used by Read().
997 explicit LinearClassifierFst(std::shared_ptr<Impl> impl)
998 : ImplToFst<Impl>(impl) {}
1000 void operator=(const LinearClassifierFst<A> &fst) = delete;
1003 // Specialization for LinearClassifierFst.
// Enumerates states through the cache of the FST's (mutable) impl.
1004 template <class Arc>
1005 class StateIterator<LinearClassifierFst<Arc>>
1006 : public CacheStateIterator<LinearClassifierFst<Arc>> {
1008 explicit StateIterator(const LinearClassifierFst<Arc> &fst)
1009 : CacheStateIterator<LinearClassifierFst<Arc>>(fst,
1010 fst.GetMutableImpl()) {}
1013 // Specialization for LinearClassifierFst.
// Expands state `s` if its arcs are not cached yet, then iterates the cache.
1014 template <class Arc>
1015 class ArcIterator<LinearClassifierFst<Arc>>
1016 : public CacheArcIterator<LinearClassifierFst<Arc>> {
1018 using StateId = typename Arc::StateId;
1020 ArcIterator(const LinearClassifierFst<Arc> &fst, StateId s)
1021 : CacheArcIterator<LinearClassifierFst<Arc>>(fst.GetMutableImpl(), s) {
1022 if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
// Stores a newly allocated StateIterator for this FST in `data->base`.
1026 template <class Arc>
1027 inline void LinearClassifierFst<Arc>::InitStateIterator(
1028 StateIteratorData<Arc> *data) const {
1029 data->base = new StateIterator<LinearClassifierFst<Arc>>(*this);
1032 // Specialized Matcher for LinearFsts. This matcher only supports
1033 // matching from the input side. This is intentional because comparing
1034 // the scores of different input sequences with the same output
1035 // sequence is meaningless in a discriminative model.
// Find() behaves as follows:
// - Find(0) first yields an implicit self-loop (`loop_`: ilabel kNoLabel,
//   olabel 0, weight One) to the current state, then the epsilon-input arcs
//   computed by the impl's MatchInput().
// - Find(kNoLabel) yields only those epsilon-input arcs.
// - Any other label yields the arcs that MatchInput() returns for it.
// NOTE(review): the `template <class F>` header, access specifiers, the `FST`
// typedef and several member lines are missing from this copy.
1037 class LinearFstMatcherTpl : public MatcherBase<typename F::Arc> {
1039 typedef typename F::Arc Arc;
1040 typedef typename Arc::Label Label;
1041 typedef typename Arc::Weight Weight;
1042 typedef typename Arc::StateId StateId;
// A match type rejected by the switch is reported via FSTERROR() and replaced
// by MATCH_NONE (the accepted cases are missing from this copy).
1045 LinearFstMatcherTpl(const FST &fst, MatchType match_type)
1047 match_type_(match_type),
1049 current_loop_(false),
1050 loop_(kNoLabel, 0, Weight::One(), kNoStateId),
1053 switch (match_type_) {
1059 FSTERROR() << "LinearFstMatcherTpl: Bad match type";
1060 match_type_ = MATCH_NONE;
// Copies `matcher`, including a Copy(safe) of its FST and its error flag.
1065 LinearFstMatcherTpl(const LinearFstMatcherTpl<F> &matcher, bool safe = false)
1066 : fst_(matcher.fst_->Copy(safe)),
1067 match_type_(matcher.match_type_),
1069 current_loop_(false),
1070 loop_(matcher.loop_),
1072 error_(matcher.error_) {}
1074 LinearFstMatcherTpl<F> *Copy(bool safe = false) const override {
1075 return new LinearFstMatcherTpl<F>(*this, safe);
1078 MatchType Type(bool /*test*/) const override {
1079 // `MATCH_INPUT` is the only valid type
1080 return match_type_ == MATCH_INPUT ? match_type_ : MATCH_NONE;
// Re-targets the implicit self-loop at `s`; a no-op if `s` is already current.
1083 void SetState(StateId s) final {
1084 if (s_ == s) return;
1086 // `MATCH_INPUT` is the only valid type
1087 if (match_type_ != MATCH_INPUT) {
1088 FSTERROR() << "LinearFstMatcherTpl: Bad match type";
1091 loop_.nextstate = s;
1094 bool Find(Label label) final {
1096 current_loop_ = false;
// The implicit loop is offered only for label 0, so the check happens before
// kNoLabel is mapped to epsilon for MatchInput().
1099 current_loop_ = label == 0;
1100 if (label == kNoLabel) label = 0;
1103 fst_->GetMutableImpl()->MatchInput(s_, label, &arcs_);
1104 return current_loop_ || !arcs_.empty();
1107 bool Done() const final {
1108 return !(current_loop_ || cur_arc_ < arcs_.size());
1111 const Arc &Value() const final {
1112 return current_loop_ ? loop_ : arcs_[cur_arc_];
1117 current_loop_ = false;
1122 ssize_t Priority(StateId s) final { return kRequirePriority; }
1124 const FST &GetFst() const override { return *fst_; }
1126 uint64 Properties(uint64 props) const override {
1127 if (error_) props |= kError;
1131 uint32 Flags() const override { return kRequireMatch; }
1134 std::unique_ptr<const FST> fst_;
1135 MatchType match_type_; // Type of match to perform.
1136 StateId s_; // Current state.
1137 bool current_loop_; // Current arc is the implicit loop.
1138 Arc loop_; // For non-consuming symbols.
1139 // All out-going arcs matching the label in last Find() call.
1140 std::vector<Arc> arcs_;
1141 size_t cur_arc_; // Index to the arc that `Value()` should return.
1142 bool error_; // Error encountered.
1147 #endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_H_