Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / determinize.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes to determinize an FST.
5
6 #ifndef FST_DETERMINIZE_H_
7 #define FST_DETERMINIZE_H_
8
9 #include <algorithm>
10 #include <climits>
11 #include <forward_list>
12 #include <map>
13 #include <string>
14 #include <vector>
15
16 #include <fst/log.h>
17
18 #include <fst/arc-map.h>
19 #include <fst/bi-table.h>
20 #include <fst/cache.h>
21 #include <fst/factor-weight.h>
22 #include <fst/filter-state.h>
23 #include <fst/prune.h>
24 #include <fst/test-properties.h>
25
26
27 namespace fst {
28
29 // Common divisors are used in determinization to compute transition weights.
30 // In the simplest case, it is the same as semiring Plus, but other choices
31 // permit more efficient determinization when the output contains strings.
32
33 // The default common divisor uses the semiring Plus.
34 template <class W>
35 struct DefaultCommonDivisor {
36  public:
37   using Weight = W;
38
39   Weight operator()(const Weight &w1, const Weight &w2) const {
40     return Plus(w1, w2);
41   }
42 };
43
44 // The label common divisor for a (left) string semiring selects a single
45 // letter common prefix or the empty string. This is used in the
46 // determinization of output strings so that at most a single letter will
47 // appear in the output of a transtion.
48 template <typename Label, StringType S>
49 struct LabelCommonDivisor {
50  public:
51   using Weight = StringWeight<Label, S>;
52
53   Weight operator()(const Weight &w1, const Weight &w2) const {
54     typename Weight::Iterator iter1(w1);
55     typename Weight::Iterator iter2(w2);
56     if (!(StringWeight<Label, S>::Properties() & kLeftSemiring)) {
57       FSTERROR() << "LabelCommonDivisor: Weight needs to be left semiring";
58       return Weight::NoWeight();
59     } else if (w1.Size() == 0 || w2.Size() == 0) {
60       return Weight::One();
61     } else if (w1 == Weight::Zero()) {
62       return Weight(iter2.Value());
63     } else if (w2 == Weight::Zero()) {
64       return Weight(iter1.Value());
65     } else if (iter1.Value() == iter2.Value()) {
66       return Weight(iter1.Value());
67     } else {
68       return Weight::One();
69     }
70   }
71 };
72
73 // The gallic common divisor uses the label common divisor on the string
74 // component and the common divisor on the weight component, which defaults to
75 // the default common divisor.
76 template <class Label, class W, GallicType G,
77           class CommonDivisor = DefaultCommonDivisor<W>>
78 class GallicCommonDivisor {
79  public:
80   using Weight = GallicWeight<Label, W, G>;
81
82   Weight operator()(const Weight &w1, const Weight &w2) const {
83     return Weight(label_common_divisor_(w1.Value1(), w2.Value1()),
84                   weight_common_divisor_(w1.Value2(), w2.Value2()));
85   }
86
87  private:
88   LabelCommonDivisor<Label, GallicStringType(G)> label_common_divisor_;
89   CommonDivisor weight_common_divisor_;
90 };
91
92 // Specialization for general GALLIC weight.
93 template <class Label, class W, class CommonDivisor>
94 class GallicCommonDivisor<Label, W, GALLIC, CommonDivisor> {
95  public:
96   using Weight = GallicWeight<Label, W, GALLIC>;
97   using GRWeight = GallicWeight<Label, W, GALLIC_RESTRICT>;
98   using Iterator =
99       UnionWeightIterator<GRWeight, GallicUnionWeightOptions<Label, W>>;
100
101   Weight operator()(const Weight &w1, const Weight &w2) const {
102     auto weight = GRWeight::Zero();
103     for (Iterator iter(w1); !iter.Done(); iter.Next()) {
104       weight = common_divisor_(weight, iter.Value());
105     }
106     for (Iterator iter(w2); !iter.Done(); iter.Next()) {
107       weight = common_divisor_(weight, iter.Value());
108     }
109     return weight == GRWeight::Zero() ? Weight::Zero() : Weight(weight);
110   }
111
112  private:
113   GallicCommonDivisor<Label, W, GALLIC_RESTRICT, CommonDivisor> common_divisor_;
114 };
115
116 namespace internal {
117
118 // Represents an element in a subset
119 template <class Arc>
120 struct DeterminizeElement {
121   using StateId = typename Arc::StateId;
122   using Weight = typename Arc::Weight;
123
124   DeterminizeElement(StateId s, Weight weight)
125       : state_id(s), weight(std::move(weight)) {}
126
127   inline bool operator==(const DeterminizeElement<Arc> &element) const {
128     return state_id == element.state_id && weight == element.weight;
129   }
130
131   inline bool operator!=(const DeterminizeElement<Arc> &element) const {
132     return !(*this == element);
133   }
134
135   inline bool operator<(const DeterminizeElement<Arc> &element) const {
136     return state_id < element.state_id;
137   }
138
139   StateId state_id;  // Input state ID.
140   Weight weight;     // Residual weight.
141 };
142
143 // Represents a weighted subset and determinization filter state
144 template <typename A, typename FilterState>
145 struct DeterminizeStateTuple {
146   using Arc = A;
147   using Element = DeterminizeElement<Arc>;
148   using Subset = std::forward_list<Element>;
149
150   DeterminizeStateTuple() : filter_state(FilterState::NoState()) {}
151
152   inline bool operator==(
153       const DeterminizeStateTuple<Arc, FilterState> &tuple) const {
154     return (tuple.filter_state == filter_state) && (tuple.subset == subset);
155   }
156
157   inline bool operator!=(
158       const DeterminizeStateTuple<Arc, FilterState> &tuple) const {
159     return (tuple.filter_state != filter_state) || (tuple.subset != subset);
160   }
161
162   Subset subset;
163   FilterState filter_state;
164 };
165
166 // Proto-transition for determinization.
167 template <class StateTuple>
168 struct DeterminizeArc {
169   using Arc = typename StateTuple::Arc;
170   using Label = typename Arc::Label;
171   using Weight = typename Arc::Weight;
172
173   DeterminizeArc()
174       : label(kNoLabel), weight(Weight::Zero()), dest_tuple(nullptr) {}
175
176   explicit DeterminizeArc(const Arc &arc)
177       : label(arc.ilabel), weight(Weight::Zero()), dest_tuple(new StateTuple) {}
178
179   Label label;             // Arc label.
180   Weight weight;           // Arc weight.
181   StateTuple *dest_tuple;  // Destination subset and filter state.
182 };
183
184 }  // namespace internal
185
186 // Determinization filters are used to compute destination state tuples based
187 // on the source tuple, transition, and destination element or on similar
188 // super-final transition information. The filter operates on a map between a
189 // label and the corresponding destination state tuples. It must define the map
190 // type LabelMap. The default filter is used for weighted determinization.
191 // A determinize filter for implementing weighted determinization.
192 template <class Arc>
193 class DefaultDeterminizeFilter {
194  public:
195   using Label = typename Arc::Label;
196   using StateId = typename Arc::StateId;
197   using Weight = typename Arc::Weight;
198
199   using FilterState = CharFilterState;
200   using Element = internal::DeterminizeElement<Arc>;
201   using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
202   using LabelMap = std::map<Label, internal::DeterminizeArc<StateTuple>>;
203
204   // This is needed e.g. to go into the gallic domain for transducers.
205   template <class A>
206   struct rebind {
207     using Other = DefaultDeterminizeFilter<A>;
208   };
209
210   explicit DefaultDeterminizeFilter(const Fst<Arc> &fst) : fst_(fst.Copy()) {}
211
212   // This is needed (e.g.) to go into the gallic domain for transducers.
213   // Ownership of the templated filter argument is given to this class.
214   template <class Filter>
215   DefaultDeterminizeFilter(const Fst<Arc> &fst, Filter *filter)
216       : fst_(fst.Copy()) {
217     delete filter;
218   }
219
220   // Copy constructor; the FST can be passed if it has been deep-copied.
221   DefaultDeterminizeFilter(const DefaultDeterminizeFilter<Arc> &filter,
222                            const Fst<Arc> *fst = nullptr)
223       : fst_(fst ? fst->Copy() : filter.fst_->Copy()) {}
224
225   FilterState Start() const { return FilterState(0); }
226
227   // Does no work.
228   void SetState(StateId s, const StateTuple &tuple) {}
229
230   // Filters transition, possibly modifying label map. Returns true if arc is
231   // added to the label map.
232   bool FilterArc(const Arc &arc, const Element &src_element,
233                  const Element &dest_element, LabelMap *label_map) const {
234     // Adds element to unique state tuple for arc label.
235     auto &det_arc = (*label_map)[arc.ilabel];
236     if (det_arc.label == kNoLabel) {
237       det_arc = internal::DeterminizeArc<StateTuple>(arc);
238       det_arc.dest_tuple->filter_state = FilterState(0);
239     }
240     det_arc.dest_tuple->subset.push_front(dest_element);
241     return true;
242   }
243
244   // Filters super-final transition, returning new final weight.
245   Weight FilterFinal(Weight weight, const Element &element) { return weight; }
246
247   static uint64 Properties(uint64 props) { return props; }
248
249  private:
250   std::unique_ptr<Fst<Arc>> fst_;
251 };
252
253 // Determinization state table interface:
254 //
255 // template <class Arc, class FilterState>
256 // class DeterminizeStateTable {
257 //  public:
258 //   using StateId = typename Arc::StateId;
259 //   using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
260 //
261 //   // Required sub-class. This is needed (e.g.) to go into the gallic domain.
262 //   template <class B, class G>
263 //   struct rebind {
264 //     using Other = DeterminizeStateTable<B, G>;
265 //   }
266 //
267 //   // Required constuctor.
268 //   DeterminizeStateTable();
269 //
270 //   // Required copy constructor that does not copy state.
271 //   DeterminizeStateTable(const DeterminizeStateTable<Arc, FilterState>
272 //   &table);
273 //
274 //   // Looks up state ID by state tuple; if it doesn't exist, then adds it.
275 //   // FindState takes ownership of the state tuple argument so that it
276 //   // doesn't have to copy it if it creates a new state.
277 //   StateId FindState(StateTuple *tuple);
278 //
279 //   // Looks up state tuple by ID.
280 //   const StateTuple *Tuple(StateId id) const;
281 // };
282
283 // The default determinization state table based on the compact hash bi-table.
284 template <class Arc, class FilterState>
285 class DefaultDeterminizeStateTable {
286  public:
287   using Label = typename Arc::Label;
288   using StateId = typename Arc::StateId;
289   using Weight = typename Arc::Weight;
290
291   using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
292   using Element = typename StateTuple::Element;
293   using Subset = typename StateTuple::Subset;
294
295   template <class B, class G>
296   struct rebind {
297     using Other = DefaultDeterminizeStateTable<B, G>;
298   };
299
300   explicit DefaultDeterminizeStateTable(size_t table_size = 0)
301       : table_size_(table_size), tuples_(table_size_) {}
302
303   DefaultDeterminizeStateTable(
304       const DefaultDeterminizeStateTable<Arc, FilterState> &table)
305       : table_size_(table.table_size_), tuples_(table_size_) {}
306
307   ~DefaultDeterminizeStateTable() {
308     for (StateId s = 0; s < tuples_.Size(); ++s) delete tuples_.FindEntry(s);
309   }
310
311   // Finds the state corresponding to a state tuple. Only creates a new state if
312   // the tuple is not found. FindState takes ownership of the tuple argument so
313   // that it doesn't have to copy it if it creates a new state.
314   StateId FindState(StateTuple *tuple) {
315     const StateId ns = tuples_.Size();
316     const auto s = tuples_.FindId(tuple);
317     if (s != ns) delete tuple;  // Tuple found.
318     return s;
319   }
320
321   const StateTuple *Tuple(StateId s) { return tuples_.FindEntry(s); }
322
323  private:
324   // Comparison object for StateTuples.
325   class StateTupleEqual {
326    public:
327     bool operator()(const StateTuple *tuple1, const StateTuple *tuple2) const {
328       return *tuple1 == *tuple2;
329     }
330   };
331
332   // Hash function for StateTuples.
333   class StateTupleKey {
334    public:
335     size_t operator()(const StateTuple *tuple) const {
336       size_t h = tuple->filter_state.Hash();
337       for (auto it = tuple->subset.begin(); it != tuple->subset.end(); ++it) {
338         const size_t h1 = it->state_id;
339         static constexpr auto lshift = 5;
340         static constexpr auto rshift = CHAR_BIT * sizeof(size_t) - 5;
341         h ^= h << 1 ^ h1 << lshift ^ h1 >> rshift ^ it->weight.Hash();
342       }
343       return h;
344     }
345   };
346
347   size_t table_size_;
348   CompactHashBiTable<StateId, StateTuple *, StateTupleKey, StateTupleEqual,
349                      HS_STL>
350       tuples_;
351
352   DefaultDeterminizeStateTable &operator=(
353       const DefaultDeterminizeStateTable &) = delete;
354 };
355
356 // Determinization type.
357 enum DeterminizeType {
358   // Input transducer is known to be functional (or error).
359   DETERMINIZE_FUNCTIONAL,  // Input transducer is functional (error if not).
360   // Input transducer is not known to be functional.
361   DETERMINIZE_NONFUNCTIONAL,
362   // Input transducer is not known to be functional but only keep the min of
363   // of ambiguous outputs.
364   DETERMINIZE_DISAMBIGUATE
365 };
366
367 // Options for finite-state transducer determinization templated on the arc
368 // type, common divisor, the determinization filter and the state table.
369 // DeterminizeFst takes ownership of the determinization filter and state table,
370 // if provided.
371 template <class Arc,
372           class CommonDivisor = DefaultCommonDivisor<typename Arc::Weight>,
373           class Filter = DefaultDeterminizeFilter<Arc>,
374           class StateTable =
375               DefaultDeterminizeStateTable<Arc, typename Filter::FilterState>>
376 struct DeterminizeFstOptions : public CacheOptions {
377   using Label = typename Arc::Label;
378
379   float delta;                // Quantization delta for subset weights.
380   Label subsequential_label;  // Label used for residual final output
381                               // when producing subsequential transducers.
382   DeterminizeType type;       // Determinization type.
383   bool increment_subsequential_label;  // When creating several subsequential
384                                        // arcs at a given state, make their
385                                        // label distinct by incrementing.
386   Filter *filter;                      // Determinization filter;
387                                        // DeterminizeFst takes ownership.
388   StateTable *state_table;             // Determinization state table;
389                                        // DeterminizeFst takes ownership.
390
391   explicit DeterminizeFstOptions(const CacheOptions &opts, float delta = kDelta,
392                                  Label subsequential_label = 0,
393                                  DeterminizeType type = DETERMINIZE_FUNCTIONAL,
394                                  bool increment_subsequential_label = false,
395                                  Filter *filter = nullptr,
396                                  StateTable *state_table = nullptr)
397       : CacheOptions(opts),
398         delta(delta),
399         subsequential_label(subsequential_label),
400         type(type),
401         increment_subsequential_label(increment_subsequential_label),
402         filter(filter),
403         state_table(state_table) {}
404
405   explicit DeterminizeFstOptions(float delta = kDelta,
406                                  Label subsequential_label = 0,
407                                  DeterminizeType type = DETERMINIZE_FUNCTIONAL,
408                                  bool increment_subsequential_label = false,
409                                  Filter *filter = nullptr,
410                                  StateTable *state_table = nullptr)
411       : delta(delta),
412         subsequential_label(subsequential_label),
413         type(type),
414         increment_subsequential_label(increment_subsequential_label),
415         filter(filter),
416         state_table(state_table) {}
417 };
418
419 namespace internal {
420
421 // Implementation of delayed DeterminizeFst. This base class is
422 // common to the variants that implement acceptor and transducer
423 // determinization.
424 template <class Arc>
425 class DeterminizeFstImplBase : public CacheImpl<Arc> {
426  public:
427   using Label = typename Arc::Label;
428   using StateId = typename Arc::StateId;
429   using Weight = typename Arc::Weight;
430
431   using Store = DefaultCacheStore<Arc>;
432   using State = typename Store::State;
433
434   using FstImpl<Arc>::SetType;
435   using FstImpl<Arc>::SetProperties;
436   using FstImpl<Arc>::Properties;
437   using FstImpl<Arc>::SetInputSymbols;
438   using FstImpl<Arc>::SetOutputSymbols;
439
440   using CacheBaseImpl<CacheState<Arc>>::HasStart;
441   using CacheBaseImpl<CacheState<Arc>>::HasFinal;
442   using CacheBaseImpl<CacheState<Arc>>::HasArcs;
443   using CacheBaseImpl<CacheState<Arc>>::SetFinal;
444   using CacheBaseImpl<CacheState<Arc>>::SetStart;
445
446   template <class CommonDivisor, class Filter, class StateTable>
447   DeterminizeFstImplBase(
448       const Fst<Arc> &fst,
449       const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
450       : CacheImpl<Arc>(opts), fst_(fst.Copy()) {
451     SetType("determinize");
452     const auto iprops = fst.Properties(kFstProperties, false);
453     const auto dprops =
454         DeterminizeProperties(iprops, opts.subsequential_label != 0,
455                               opts.type == DETERMINIZE_NONFUNCTIONAL
456                                   ? opts.increment_subsequential_label
457                                   : true);
458     SetProperties(Filter::Properties(dprops), kCopyProperties);
459     SetInputSymbols(fst.InputSymbols());
460     SetOutputSymbols(fst.OutputSymbols());
461   }
462
463   DeterminizeFstImplBase(const DeterminizeFstImplBase<Arc> &impl)
464       : CacheImpl<Arc>(impl), fst_(impl.fst_->Copy(true)) {
465     SetType("determinize");
466     SetProperties(impl.Properties(), kCopyProperties);
467     SetInputSymbols(impl.InputSymbols());
468     SetOutputSymbols(impl.OutputSymbols());
469   }
470
471   virtual DeterminizeFstImplBase<Arc> *Copy() const = 0;
472
473   StateId Start() {
474     if (!HasStart()) {
475       const auto start = ComputeStart();
476       if (start != kNoStateId) SetStart(start);
477     }
478     return CacheImpl<Arc>::Start();
479   }
480
481   Weight Final(StateId s) {
482     if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
483     return CacheImpl<Arc>::Final(s);
484   }
485
486   virtual void Expand(StateId s) = 0;
487
488   size_t NumArcs(StateId s) {
489     if (!HasArcs(s)) Expand(s);
490     return CacheImpl<Arc>::NumArcs(s);
491   }
492
493   size_t NumInputEpsilons(StateId s) {
494     if (!HasArcs(s)) Expand(s);
495     return CacheImpl<Arc>::NumInputEpsilons(s);
496   }
497
498   size_t NumOutputEpsilons(StateId s) {
499     if (!HasArcs(s)) Expand(s);
500     return CacheImpl<Arc>::NumOutputEpsilons(s);
501   }
502
503   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
504     if (!HasArcs(s)) Expand(s);
505     CacheImpl<Arc>::InitArcIterator(s, data);
506   }
507
508   virtual StateId ComputeStart() = 0;
509
510   virtual Weight ComputeFinal(StateId s) = 0;
511
512   const Fst<Arc> &GetFst() const { return *fst_; }
513
514  private:
515   std::unique_ptr<const Fst<Arc>> fst_;  // Input FST.
516 };
517
518 // Implementation of delayed determinization for weighted acceptors.
519 template <class Arc, class CommonDivisor, class Filter, class StateTable>
520 class DeterminizeFsaImpl : public DeterminizeFstImplBase<Arc> {
521  public:
522   using Label = typename Arc::Label;
523   using StateId = typename Arc::StateId;
524   using Weight = typename Arc::Weight;
525
526   using FilterState = typename Filter::FilterState;
527   using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
528   using Element = typename StateTuple::Element;
529   using Subset = typename StateTuple::Subset;
530   using LabelMap = typename Filter::LabelMap;
531
532   using FstImpl<Arc>::SetProperties;
533   using DeterminizeFstImplBase<Arc>::GetFst;
534   using DeterminizeFstImplBase<Arc>::SetArcs;
535
536   DeterminizeFsaImpl(
537       const Fst<Arc> &fst, const std::vector<Weight> *in_dist,
538       std::vector<Weight> *out_dist,
539       const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
540       : DeterminizeFstImplBase<Arc>(fst, opts),
541         delta_(opts.delta),
542         in_dist_(in_dist),
543         out_dist_(out_dist),
544         filter_(opts.filter ? opts.filter : new Filter(fst)),
545         state_table_(opts.state_table ? opts.state_table : new StateTable()) {
546     if (!fst.Properties(kAcceptor, true)) {
547       FSTERROR() << "DeterminizeFst: Argument not an acceptor";
548       SetProperties(kError, kError);
549     }
550     if (!(Weight::Properties() & kLeftSemiring)) {
551       FSTERROR() << "DeterminizeFst: Weight must be left distributive: "
552                  << Weight::Type();
553       SetProperties(kError, kError);
554     }
555     if (out_dist_) out_dist_->clear();
556   }
557
558   DeterminizeFsaImpl(
559       const DeterminizeFsaImpl<Arc, CommonDivisor, Filter, StateTable> &impl)
560       : DeterminizeFstImplBase<Arc>(impl),
561         delta_(impl.delta_),
562         in_dist_(nullptr),
563         out_dist_(nullptr),
564         filter_(new Filter(*impl.filter_, &GetFst())),
565         state_table_(new StateTable(*impl.state_table_)) {
566     if (impl.out_dist_) {
567       FSTERROR() << "DeterminizeFsaImpl: Cannot copy with out_dist vector";
568       SetProperties(kError, kError);
569     }
570   }
571
572   DeterminizeFsaImpl<Arc, CommonDivisor, Filter, StateTable> *Copy()
573       const override {
574     return new DeterminizeFsaImpl<Arc, CommonDivisor, Filter, StateTable>(
575         *this);
576   }
577
578   uint64 Properties() const override { return Properties(kFstProperties); }
579
580   // Sets error if found, and returns other FST impl properties.
581   uint64 Properties(uint64 mask) const override {
582     if ((mask & kError) && (GetFst().Properties(kError, false))) {
583       SetProperties(kError, kError);
584     }
585     return FstImpl<Arc>::Properties(mask);
586   }
587
588   StateId ComputeStart() override {
589     const auto s = GetFst().Start();
590     if (s == kNoStateId) return kNoStateId;
591     const Element element(s, Weight::One());
592     auto *tuple = new StateTuple;
593     tuple->subset.push_front(element);
594     tuple->filter_state = filter_->Start();
595     return FindState(tuple);
596   }
597
598   Weight ComputeFinal(StateId s) override {
599     const auto *tuple = state_table_->Tuple(s);
600     filter_->SetState(s, *tuple);
601     auto final_weight = Weight::Zero();
602     for (auto it = tuple->subset.begin(); it != tuple->subset.end(); ++it) {
603       const auto &element = *it;
604       final_weight =
605           Plus(final_weight,
606                Times(element.weight, GetFst().Final(element.state_id)));
607       final_weight = filter_->FilterFinal(final_weight, element);
608       if (!final_weight.Member()) SetProperties(kError, kError);
609     }
610     return final_weight;
611   }
612
613   StateId FindState(StateTuple *tuple) {
614     const auto s = state_table_->FindState(tuple);
615     if (in_dist_ && out_dist_->size() <= s) {
616       out_dist_->push_back(ComputeDistance(tuple->subset));
617     }
618     return s;
619   }
620
621   // Computes distance from a state to the final states in the DFA given the
622   // distances in the NFA.
623   Weight ComputeDistance(const Subset &subset) {
624     auto outd = Weight::Zero();
625     for (auto it = subset.begin(); it != subset.end(); ++it) {
626       const auto &element = *it;
627       const auto ind =
628           (element.state_id < in_dist_->size() ? (*in_dist_)[element.state_id]
629                                                : Weight::Zero());
630       outd = Plus(outd, Times(element.weight, ind));
631     }
632     return outd;
633   }
634
635   // Computes the outgoing transitions from a state, creating new destination
636   // states as needed.
637   void Expand(StateId s) override {
638     LabelMap label_map;
639     GetLabelMap(s, &label_map);
640     for (auto it = label_map.begin(); it != label_map.end(); ++it) {
641       AddArc(s, it->second);
642     }
643     SetArcs(s);
644   }
645
646  private:
647   using DetArc = internal::DeterminizeArc<StateTuple>;
648
649   // Constructs proto-determinization transition, including destination subset,
650   // per label.
651   void GetLabelMap(StateId s, LabelMap *label_map) {
652     const auto *src_tuple = state_table_->Tuple(s);
653     filter_->SetState(s, *src_tuple);
654     for (auto it = src_tuple->subset.begin(); it != src_tuple->subset.end();
655          ++it) {
656       const auto &src_element = *it;
657       for (ArcIterator<Fst<Arc>> aiter(GetFst(), src_element.state_id);
658            !aiter.Done(); aiter.Next()) {
659         const auto &arc = aiter.Value();
660         const Element dest_element(arc.nextstate,
661                                    Times(src_element.weight, arc.weight));
662         filter_->FilterArc(arc, src_element, dest_element, label_map);
663       }
664     }
665     for (auto it = label_map->begin(); it != label_map->end(); ++it) {
666       NormArc(&it->second);
667     }
668   }
669
670   // Sorts subsets and removes duplicate elements, normalizing transition and
671   // subset weights.
672   void NormArc(DetArc *det_arc) {
673     auto *dest_tuple = det_arc->dest_tuple;
674     dest_tuple->subset.sort();
675     auto piter = dest_tuple->subset.begin();
676     for (auto diter = dest_tuple->subset.begin();
677          diter != dest_tuple->subset.end();) {
678       auto &dest_element = *diter;
679       auto &prev_element = *piter;
680       // Computes arc weight.
681       det_arc->weight = common_divisor_(det_arc->weight, dest_element.weight);
682       if (piter != diter && dest_element.state_id == prev_element.state_id) {
683         // Found duplicate state: sums state weight and deletes duplicate.
684         prev_element.weight = Plus(prev_element.weight, dest_element.weight);
685         if (!prev_element.weight.Member()) SetProperties(kError, kError);
686         ++diter;
687         dest_tuple->subset.erase_after(piter);
688       } else {
689         piter = diter;
690         ++diter;
691       }
692     }
693     // Divides out label weight from destination subset elements, quantizing to
694     // ensure comparisons are effective.
695     for (auto diter = dest_tuple->subset.begin();
696          diter != dest_tuple->subset.end(); ++diter) {
697       auto &dest_element = *diter;
698       dest_element.weight =
699           Divide(dest_element.weight, det_arc->weight, DIVIDE_LEFT);
700       dest_element.weight = dest_element.weight.Quantize(delta_);
701     }
702   }
703
704   // Adds an arc from state S to the destination state associated with state
705   // tuple in det_arc as created by GetLabelMap.
706   void AddArc(StateId s, const DetArc &det_arc) {
707     const Arc arc(det_arc.label, det_arc.label, det_arc.weight,
708                   FindState(det_arc.dest_tuple));
709     CacheImpl<Arc>::PushArc(s, arc);
710   }
711
712   float delta_;                         // Quantization delta for weights.
713   const std::vector<Weight> *in_dist_;  // Distance to final NFA states.
714   std::vector<Weight> *out_dist_;       // Distance to final DFA states.
715
716   // FIXME(kbg): Ought to be static const?
717   CommonDivisor common_divisor_;
718   std::unique_ptr<Filter> filter_;
719   std::unique_ptr<StateTable> state_table_;
720 };
721
722 // Implementation of delayed determinization for transducers. Transducer
723 // determinization is implemented by mapping the input to the Gallic semiring as
724 // an acceptor whose weights contain the output strings and using acceptor
725 // determinization above to determinize that acceptor.
726 template <class Arc, GallicType G, class CommonDivisor, class Filter,
727           class StateTable>
728 class DeterminizeFstImpl : public DeterminizeFstImplBase<Arc> {
729  public:
730   using Label = typename Arc::Label;
731   using StateId = typename Arc::StateId;
732   using Weight = typename Arc::Weight;
733
734   using ToMapper = ToGallicMapper<Arc, G>;
735   using ToArc = typename ToMapper::ToArc;
736   using ToFst = ArcMapFst<Arc, ToArc, ToMapper>;
737   using FromMapper = FromGallicMapper<Arc, G>;
738   using FromFst = ArcMapFst<ToArc, Arc, FromMapper>;
739
740   using ToCommonDivisor = GallicCommonDivisor<Label, Weight, G, CommonDivisor>;
741   using ToFilter = typename Filter::template rebind<ToArc>::Other;
742   using ToFilterState = typename ToFilter::FilterState;
743   using ToStateTable =
744       typename StateTable::template rebind<ToArc, ToFilterState>::Other;
745   using FactorIterator = GallicFactor<Label, Weight, G>;
746
747   using FstImpl<Arc>::SetProperties;
748   using DeterminizeFstImplBase<Arc>::GetFst;
749   using CacheBaseImpl<CacheState<Arc>>::GetCacheGc;
750   using CacheBaseImpl<CacheState<Arc>>::GetCacheLimit;
751
752   DeterminizeFstImpl(
753       const Fst<Arc> &fst,
754       const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
755       : DeterminizeFstImplBase<Arc>(fst, opts),
756         delta_(opts.delta),
757         subsequential_label_(opts.subsequential_label),
758         increment_subsequential_label_(opts.increment_subsequential_label) {
759     if (opts.state_table) {
760       FSTERROR() << "DeterminizeFst: "
761                  << "A state table can not be passed with transducer input";
762       SetProperties(kError, kError);
763       return;
764     }
765     Init(GetFst(), opts.filter);
766   }
767
768   DeterminizeFstImpl(
769       const DeterminizeFstImpl<Arc, G, CommonDivisor, Filter, StateTable> &impl)
770       : DeterminizeFstImplBase<Arc>(impl),
771         delta_(impl.delta_),
772         subsequential_label_(impl.subsequential_label_),
773         increment_subsequential_label_(impl.increment_subsequential_label_) {
774     Init(GetFst(), nullptr);
775   }
776
777   DeterminizeFstImpl<Arc, G, CommonDivisor, Filter, StateTable> *Copy()
778       const override {
779     return new DeterminizeFstImpl<Arc, G, CommonDivisor, Filter, StateTable>(
780         *this);
781   }
782
783   uint64 Properties() const override { return Properties(kFstProperties); }
784
785   // Sets error if found, and returns other FST impl properties.
786   uint64 Properties(uint64 mask) const override {
787     if ((mask & kError) && (GetFst().Properties(kError, false) ||
788                             from_fst_->Properties(kError, false))) {
789       SetProperties(kError, kError);
790     }
791     return FstImpl<Arc>::Properties(mask);
792   }
793
794   StateId ComputeStart() override { return from_fst_->Start(); }
795
796   Weight ComputeFinal(StateId s) override { return from_fst_->Final(s); }
797
798   void Expand(StateId s) override {
799     for (ArcIterator<FromFst> aiter(*from_fst_, s); !aiter.Done();
800          aiter.Next()) {
801       CacheImpl<Arc>::PushArc(s, aiter.Value());
802     }
803     CacheImpl<Arc>::SetArcs(s);
804   }
805
806  private:
807   // Initialization of transducer determinization implementation, which is
808   // defined after DeterminizeFst since it calls it.
809   void Init(const Fst<Arc> &fst, Filter *filter);
810
811   float delta_;
812   Label subsequential_label_;
813   bool increment_subsequential_label_;
814   std::unique_ptr<FromFst> from_fst_;
815 };
816
817 }  // namespace internal
818
819 // Determinizes a weighted transducer. This version is a delayed
820 // FST. The result will be an equivalent FST that has the property
821 // that no state has two transitions with the same input label.
822 // For this algorithm, epsilon transitions are treated as regular
823 // symbols (cf. RmEpsilon).
824 //
825 // The transducer must be functional. The weights must be (weakly) left
826 // divisible (valid for TropicalWeight and LogWeight for instance) and be
827 // zero-sum-free if for all a, b: (Plus(a, b) == 0) => a = b = 0.
828 //
829 // Complexity:
830 //
831 //   Determinizable: exponential (polynomial in the size of the output).
832 //   Non-determinizable: does not terminate.
833 //
834 // The determinizable automata include all unweighted and all acyclic input.
835 //
836 // For more information, see:
837 //
838 // Mohri, M. 1997. Finite-state transducers in language and speech processing.
839 // Computational Linguistics 23(2): 269-311.
840 //
841 // This class attaches interface to implementation and handles reference
842 // counting, delegating most methods to ImplToFst.
843 template <class A>
844 class DeterminizeFst : public ImplToFst<internal::DeterminizeFstImplBase<A>> {
845  public:
846   using Arc = A;
847   using Label = typename Arc::Label;
848   using StateId = typename Arc::StateId;
849   using Weight = typename Arc::Weight;
850
851   using Store = DefaultCacheStore<Arc>;
852   using State = typename Store::State;
853   using Impl = internal::DeterminizeFstImplBase<Arc>;
854
855   friend class ArcIterator<DeterminizeFst<Arc>>;
856   friend class StateIterator<DeterminizeFst<Arc>>;
857
858   template <class B, GallicType G, class CommonDivisor, class Filter,
859             class StateTable>
860   friend class DeterminizeFstImpl;
861
862   explicit DeterminizeFst(const Fst<A> &fst)
863       : ImplToFst<Impl>(CreateImpl(fst)) {}
864
865   template <class CommonDivisor, class Filter, class StateTable>
866   DeterminizeFst(
867       const Fst<Arc> &fst,
868       const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
869           &opts =
870               DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>())
871       : ImplToFst<Impl>(CreateImpl(fst, opts)) {}
872
873   // This acceptor-only version additionally computes the distance to final
874   // states in the output if provided with those distances for the input; this
875   // is useful for e.g., computing the k-shortest unique paths.
876   template <class CommonDivisor, class Filter, class StateTable>
877   DeterminizeFst(
878       const Fst<Arc> &fst, const std::vector<Weight> *in_dist,
879       std::vector<Weight> *out_dist,
880       const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
881           &opts =
882               DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>())
883       : ImplToFst<Impl>(
884             std::make_shared<internal::DeterminizeFsaImpl<Arc, CommonDivisor,
885                                                           Filter, StateTable>>(
886                 fst, in_dist, out_dist, opts)) {
887     if (!fst.Properties(kAcceptor, true)) {
888       FSTERROR() << "DeterminizeFst: "
889                  << "Distance to final states computed for acceptors only";
890       GetMutableImpl()->SetProperties(kError, kError);
891     }
892   }
893
894   // See Fst<>::Copy() for doc.
895   DeterminizeFst(const DeterminizeFst<Arc> &fst, bool safe = false)
896       : ImplToFst<Impl>(safe ? std::shared_ptr<Impl>(fst.GetImpl()->Copy())
897                              : fst.GetSharedImpl()) {}
898
899   // Get a copy of this DeterminizeFst. See Fst<>::Copy() for further doc.
900   DeterminizeFst<Arc> *Copy(bool safe = false) const override {
901     return new DeterminizeFst<Arc>(*this, safe);
902   }
903
904   inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
905
906   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
907     GetMutableImpl()->InitArcIterator(s, data);
908   }
909
910  private:
911   using ImplToFst<Impl>::GetImpl;
912   using ImplToFst<Impl>::GetMutableImpl;
913
914   static std::shared_ptr<Impl> CreateImpl(const Fst<Arc> &fst) {
915     using D = DefaultCommonDivisor<Weight>;
916     using F = DefaultDeterminizeFilter<Arc>;
917     using T = DefaultDeterminizeStateTable<Arc, typename F::FilterState>;
918     const DeterminizeFstOptions<Arc, D, F, T> opts;
919     return CreateImpl(fst, opts);
920   }
921
922   template <class CommonDivisor, class Filter, class StateTable>
923   static std::shared_ptr<Impl> CreateImpl(
924       const Fst<Arc> &fst,
925       const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
926           &opts) {
927     if (fst.Properties(kAcceptor, true)) {
928       // Calls implementation for acceptors.
929       return std::make_shared<
930           internal::DeterminizeFsaImpl<Arc, CommonDivisor, Filter, StateTable>>(
931           fst, nullptr, nullptr, opts);
932     } else if (opts.type == DETERMINIZE_DISAMBIGUATE) {
933       auto rv = std::make_shared<internal::DeterminizeFstImpl<
934           Arc, GALLIC_MIN, CommonDivisor, Filter, StateTable>>(fst, opts);
935       if (!(Weight::Properties() & kPath)) {
936         FSTERROR() << "DeterminizeFst: Weight needs to have the "
937                    << "path property to disambiguate output: "
938                    << Weight::Type();
939         rv->SetProperties(kError, kError);
940       }
941       // Calls disambiguating implementation for non-functional transducers.
942       return rv;
943     } else if (opts.type == DETERMINIZE_FUNCTIONAL) {
944       // Calls implementation for functional transducers.
945       return std::make_shared<internal::DeterminizeFstImpl<
946           Arc, GALLIC_RESTRICT, CommonDivisor, Filter, StateTable>>(fst, opts);
947     } else {  // opts.type == DETERMINIZE_NONFUNCTIONAL
948       // Calls implementation for non functional transducers;
949       return std::make_shared<internal::DeterminizeFstImpl<
950           Arc, GALLIC, CommonDivisor, Filter, StateTable>>(fst, opts);
951     }
952   }
953
954   DeterminizeFst &operator=(const DeterminizeFst &) = delete;
955 };
956
957 namespace internal {
958
959 // Initialization of transducer determinization implementation, which is defined
960 // after DeterminizeFst since it calls it.
961 template <class A, GallicType G, class D, class F, class T>
962 void DeterminizeFstImpl<A, G, D, F, T>::Init(const Fst<A> &fst, F *filter) {
963   // Mapper to an acceptor.
964   const ToFst to_fst(fst, ToMapper());
965   auto *to_filter = filter ? new ToFilter(to_fst, filter) : nullptr;
966   // This recursive call terminates since it is to a (non-recursive)
967   // different constructor.
968   const CacheOptions copts(GetCacheGc(), GetCacheLimit());
969   const DeterminizeFstOptions<ToArc, ToCommonDivisor, ToFilter, ToStateTable>
970       dopts(copts, delta_, 0, DETERMINIZE_FUNCTIONAL, false, to_filter);
971   // Uses acceptor-only constructor to avoid template recursion.
972   const DeterminizeFst<ToArc> det_fsa(to_fst, nullptr, nullptr, dopts);
973   // Mapper back to transducer.
974   const FactorWeightOptions<ToArc> fopts(
975       CacheOptions(true, 0), delta_, kFactorFinalWeights, subsequential_label_,
976       subsequential_label_, increment_subsequential_label_,
977       increment_subsequential_label_);
978   const FactorWeightFst<ToArc, FactorIterator> factored_fst(det_fsa, fopts);
979   from_fst_.reset(new FromFst(factored_fst, FromMapper(subsequential_label_)));
980 }
981
982 }  // namespace internal
983
984 // Specialization for DeterminizeFst.
985 template <class Arc>
986 class StateIterator<DeterminizeFst<Arc>>
987     : public CacheStateIterator<DeterminizeFst<Arc>> {
988  public:
989   explicit StateIterator(const DeterminizeFst<Arc> &fst)
990       : CacheStateIterator<DeterminizeFst<Arc>>(fst, fst.GetMutableImpl()) {}
991 };
992
993 // Specialization for DeterminizeFst.
994 template <class Arc>
995 class ArcIterator<DeterminizeFst<Arc>>
996     : public CacheArcIterator<DeterminizeFst<Arc>> {
997  public:
998   using StateId = typename Arc::StateId;
999
1000   ArcIterator(const DeterminizeFst<Arc> &fst, StateId s)
1001       : CacheArcIterator<DeterminizeFst<Arc>>(fst.GetMutableImpl(), s) {
1002     if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
1003   }
1004 };
1005
1006 template <class Arc>
1007 inline void DeterminizeFst<Arc>::InitStateIterator(
1008     StateIteratorData<Arc> *data) const {
1009   data->base = new StateIterator<DeterminizeFst<Arc>>(*this);
1010 }
1011
1012 // Useful aliases when using StdArc.
1013 using StdDeterminizeFst = DeterminizeFst<StdArc>;
1014
1015 template <class Arc>
1016 struct DeterminizeOptions {
1017   using Label = typename Arc::Label;
1018   using StateId = typename Arc::StateId;
1019   using Weight = typename Arc::Weight;
1020
1021   float delta;                // Quantization delta for subset weights.
1022   Weight weight_threshold;    // Pruning weight threshold.
1023   StateId state_threshold;    // Pruning state threshold.
1024   Label subsequential_label;  // Label used for residual final output.
1025   DeterminizeType type;
1026   bool increment_subsequential_label;  // When creating several subsequential
1027                                        // arcs at a given state, make their
1028                                        // label distinct by incrementation?
1029
1030   explicit DeterminizeOptions(float delta = kDelta,
1031                               Weight weight_threshold = Weight::Zero(),
1032                               StateId state_threshold = kNoStateId,
1033                               Label subsequential_label = 0,
1034                               DeterminizeType type = DETERMINIZE_FUNCTIONAL,
1035                               bool increment_subsequential_label = false)
1036       : delta(delta),
1037         weight_threshold(std::move(weight_threshold)),
1038         state_threshold(state_threshold),
1039         subsequential_label(subsequential_label),
1040         type(type),
1041         increment_subsequential_label(increment_subsequential_label) {}
1042 };
1043
1044 // Determinizes a weighted transducer. This version writes the
1045 // determinized Fst to an output MutableFst. The result will be an
1046 // equivalent FST that has the property that no state has two
1047 // transitions with the same input label. For this algorithm, epsilon
1048 // transitions are treated as regular symbols (cf. RmEpsilon).
1049 //
1050 // The transducer must be functional. The weights must be (weakly)
1051 // left divisible (valid for TropicalWeight and LogWeight).
1052 //
1053 // Complexity:
1054 //
1055 //   Determinizable: exponential (polynomial in the size of the output)
1056 //   Non-determinizable: does not terminate
1057 //
1058 // The determinizable automata include all unweighted and all acyclic input.
1059 template <class Arc>
1060 void Determinize(
1061     const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
1062     const DeterminizeOptions<Arc> &opts = DeterminizeOptions<Arc>()) {
1063   using Weight = typename Arc::Weight;
1064   DeterminizeFstOptions<Arc> nopts;
1065   nopts.delta = opts.delta;
1066   nopts.subsequential_label = opts.subsequential_label;
1067   nopts.type = opts.type;
1068   nopts.increment_subsequential_label = opts.increment_subsequential_label;
1069   nopts.gc_limit = 0;  // Caches only the last state for fastest copy.
1070   if (opts.weight_threshold != Weight::Zero() ||
1071       opts.state_threshold != kNoStateId) {
1072     if (ifst.Properties(kAcceptor, false)) {
1073       std::vector<Weight> idistance;
1074       std::vector<Weight> odistance;
1075       ShortestDistance(ifst, &idistance, true);
1076       DeterminizeFst<Arc> dfst(ifst, &idistance, &odistance, nopts);
1077       PruneOptions<Arc, AnyArcFilter<Arc>> popts(
1078           opts.weight_threshold, opts.state_threshold, AnyArcFilter<Arc>(),
1079           &odistance);
1080       Prune(dfst, ofst, popts);
1081     } else {
1082       *ofst = DeterminizeFst<Arc>(ifst, nopts);
1083       Prune(ofst, opts.weight_threshold, opts.state_threshold);
1084     }
1085   } else {
1086     *ofst = DeterminizeFst<Arc>(ifst, nopts);
1087   }
1088 }
1089
1090 }  // namespace fst
1091
1092 #endif  // FST_DETERMINIZE_H_