1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions and classes to determinize an FST.
6 #ifndef FST_DETERMINIZE_H_
7 #define FST_DETERMINIZE_H_
11 #include <forward_list>
18 #include <fst/arc-map.h>
19 #include <fst/bi-table.h>
20 #include <fst/cache.h>
21 #include <fst/factor-weight.h>
22 #include <fst/filter-state.h>
23 #include <fst/prune.h>
24 #include <fst/test-properties.h>
29 // Common divisors are used in determinization to compute transition weights.
30 // In the simplest case, it is the same as semiring Plus, but other choices
31 // permit more efficient determinization when the output contains strings.
33 // The default common divisor uses the semiring Plus.
35 struct DefaultCommonDivisor {
39 Weight operator()(const Weight &w1, const Weight &w2) const {
44 // The label common divisor for a (left) string semiring selects a single
45 // letter common prefix or the empty string. This is used in the
46 // determinization of output strings so that at most a single letter will
47 // appear in the output of a transition.
48 template <typename Label, StringType S>
49 struct LabelCommonDivisor {
51 using Weight = StringWeight<Label, S>;
53 Weight operator()(const Weight &w1, const Weight &w2) const {
54 typename Weight::Iterator iter1(w1);
55 typename Weight::Iterator iter2(w2);
56 if (!(StringWeight<Label, S>::Properties() & kLeftSemiring)) {
57 FSTERROR() << "LabelCommonDivisor: Weight needs to be left semiring";
58 return Weight::NoWeight();
59 } else if (w1.Size() == 0 || w2.Size() == 0) {
61 } else if (w1 == Weight::Zero()) {
62 return Weight(iter2.Value());
63 } else if (w2 == Weight::Zero()) {
64 return Weight(iter1.Value());
65 } else if (iter1.Value() == iter2.Value()) {
66 return Weight(iter1.Value());
73 // The gallic common divisor uses the label common divisor on the string
74 // component and the common divisor on the weight component, which defaults to
75 // the default common divisor.
76 template <class Label, class W, GallicType G,
77 class CommonDivisor = DefaultCommonDivisor<W>>
78 class GallicCommonDivisor {
80 using Weight = GallicWeight<Label, W, G>;
82 Weight operator()(const Weight &w1, const Weight &w2) const {
83 return Weight(label_common_divisor_(w1.Value1(), w2.Value1()),
84 weight_common_divisor_(w1.Value2(), w2.Value2()));
88 LabelCommonDivisor<Label, GallicStringType(G)> label_common_divisor_;
89 CommonDivisor weight_common_divisor_;
92 // Specialization for general GALLIC weight.
93 template <class Label, class W, class CommonDivisor>
94 class GallicCommonDivisor<Label, W, GALLIC, CommonDivisor> {
96 using Weight = GallicWeight<Label, W, GALLIC>;
97 using GRWeight = GallicWeight<Label, W, GALLIC_RESTRICT>;
99 UnionWeightIterator<GRWeight, GallicUnionWeightOptions<Label, W>>;
101 Weight operator()(const Weight &w1, const Weight &w2) const {
102 auto weight = GRWeight::Zero();
103 for (Iterator iter(w1); !iter.Done(); iter.Next()) {
104 weight = common_divisor_(weight, iter.Value());
106 for (Iterator iter(w2); !iter.Done(); iter.Next()) {
107 weight = common_divisor_(weight, iter.Value());
109 return weight == GRWeight::Zero() ? Weight::Zero() : Weight(weight);
113 GallicCommonDivisor<Label, W, GALLIC_RESTRICT, CommonDivisor> common_divisor_;
118 // Represents an element in a subset: an input state ID and its residual weight.
120 struct DeterminizeElement {
121 using StateId = typename Arc::StateId;
122 using Weight = typename Arc::Weight;
124 DeterminizeElement(StateId s, Weight weight)
125 : state_id(s), weight(std::move(weight)) {}
127 inline bool operator==(const DeterminizeElement<Arc> &element) const {
128 return state_id == element.state_id && weight == element.weight;
131 inline bool operator!=(const DeterminizeElement<Arc> &element) const {
132 return !(*this == element);
135 inline bool operator<(const DeterminizeElement<Arc> &element) const {
136 return state_id < element.state_id;
139 StateId state_id; // Input state ID.
140 Weight weight; // Residual weight.
143 // Represents a weighted subset and determinization filter state.
144 template <typename A, typename FilterState>
145 struct DeterminizeStateTuple {
147 using Element = DeterminizeElement<Arc>;
148 using Subset = std::forward_list<Element>;
150 DeterminizeStateTuple() : filter_state(FilterState::NoState()) {}
152 inline bool operator==(
153 const DeterminizeStateTuple<Arc, FilterState> &tuple) const {
154 return (tuple.filter_state == filter_state) && (tuple.subset == subset);
157 inline bool operator!=(
158 const DeterminizeStateTuple<Arc, FilterState> &tuple) const {
159 return (tuple.filter_state != filter_state) || (tuple.subset != subset);
163 FilterState filter_state;
166 // Proto-transition for determinization.
167 template <class StateTuple>
168 struct DeterminizeArc {
169 using Arc = typename StateTuple::Arc;
170 using Label = typename Arc::Label;
171 using Weight = typename Arc::Weight;
174 : label(kNoLabel), weight(Weight::Zero()), dest_tuple(nullptr) {}
176 explicit DeterminizeArc(const Arc &arc)
177 : label(arc.ilabel), weight(Weight::Zero()), dest_tuple(new StateTuple) {}
179 Label label; // Arc label.
180 Weight weight; // Arc weight.
181 StateTuple *dest_tuple; // Destination subset and filter state.
184 } // namespace internal
186 // Determinization filters are used to compute destination state tuples based
187 // on the source tuple, transition, and destination element or on similar
188 // super-final transition information. The filter operates on a map between a
189 // label and the corresponding destination state tuples. It must define the map
190 // type LabelMap. The default filter is used for weighted determinization.
191 // A determinize filter for implementing weighted determinization.
193 class DefaultDeterminizeFilter {
195 using Label = typename Arc::Label;
196 using StateId = typename Arc::StateId;
197 using Weight = typename Arc::Weight;
199 using FilterState = CharFilterState;
200 using Element = internal::DeterminizeElement<Arc>;
201 using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
202 using LabelMap = std::map<Label, internal::DeterminizeArc<StateTuple>>;
204 // This is needed e.g. to go into the gallic domain for transducers.
207 using Other = DefaultDeterminizeFilter<A>;
210 explicit DefaultDeterminizeFilter(const Fst<Arc> &fst) : fst_(fst.Copy()) {}
212 // This is needed (e.g.) to go into the gallic domain for transducers.
213 // Ownership of the templated filter argument is given to this class.
214 template <class Filter>
215 DefaultDeterminizeFilter(const Fst<Arc> &fst, Filter *filter)
220 // Copy constructor; the FST can be passed if it has been deep-copied.
221 DefaultDeterminizeFilter(const DefaultDeterminizeFilter<Arc> &filter,
222 const Fst<Arc> *fst = nullptr)
223 : fst_(fst ? fst->Copy() : filter.fst_->Copy()) {}
225 FilterState Start() const { return FilterState(0); }
228 void SetState(StateId s, const StateTuple &tuple) {}
230 // Filters transition, possibly modifying label map. Returns true if arc is
231 // added to the label map.
232 bool FilterArc(const Arc &arc, const Element &src_element,
233 const Element &dest_element, LabelMap *label_map) const {
234 // Adds element to unique state tuple for arc label.
235 auto &det_arc = (*label_map)[arc.ilabel];
236 if (det_arc.label == kNoLabel) {
237 det_arc = internal::DeterminizeArc<StateTuple>(arc);
238 det_arc.dest_tuple->filter_state = FilterState(0);
240 det_arc.dest_tuple->subset.push_front(dest_element);
244 // Filters super-final transition, returning new final weight.
245 Weight FilterFinal(Weight weight, const Element &element) { return weight; }
247 static uint64 Properties(uint64 props) { return props; }
250 std::unique_ptr<Fst<Arc>> fst_;
253 // Determinization state table interface:
255 // template <class Arc, class FilterState>
256 // class DeterminizeStateTable {
258 // using StateId = typename Arc::StateId;
259 // using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
261 // // Required sub-class. This is needed (e.g.) to go into the gallic domain.
262 // template <class B, class G>
264 // using Other = DeterminizeStateTable<B, G>;
267 // // Required constructor.
268 // DeterminizeStateTable();
270 // // Required copy constructor that does not copy state.
271 // DeterminizeStateTable(const DeterminizeStateTable<Arc, FilterState>
274 // // Looks up state ID by state tuple; if it doesn't exist, then adds it.
275 // // FindState takes ownership of the state tuple argument so that it
276 // // doesn't have to copy it if it creates a new state.
277 // StateId FindState(StateTuple *tuple);
279 // // Looks up state tuple by ID.
280 // const StateTuple *Tuple(StateId id) const;
283 // The default determinization state table based on the compact hash bi-table.
284 template <class Arc, class FilterState>
285 class DefaultDeterminizeStateTable {
287 using Label = typename Arc::Label;
288 using StateId = typename Arc::StateId;
289 using Weight = typename Arc::Weight;
291 using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
292 using Element = typename StateTuple::Element;
293 using Subset = typename StateTuple::Subset;
295 template <class B, class G>
297 using Other = DefaultDeterminizeStateTable<B, G>;
300 explicit DefaultDeterminizeStateTable(size_t table_size = 0)
301 : table_size_(table_size), tuples_(table_size_) {}
303 DefaultDeterminizeStateTable(
304 const DefaultDeterminizeStateTable<Arc, FilterState> &table)
305 : table_size_(table.table_size_), tuples_(table_size_) {}
307 ~DefaultDeterminizeStateTable() {
308 for (StateId s = 0; s < tuples_.Size(); ++s) delete tuples_.FindEntry(s);
311 // Finds the state corresponding to a state tuple. Only creates a new state if
312 // the tuple is not found. FindState takes ownership of the tuple argument so
313 // that it doesn't have to copy it if it creates a new state.
314 StateId FindState(StateTuple *tuple) {
315 const StateId ns = tuples_.Size();
316 const auto s = tuples_.FindId(tuple);
317 if (s != ns) delete tuple; // Tuple found.
321 const StateTuple *Tuple(StateId s) { return tuples_.FindEntry(s); }
324 // Comparison object for StateTuples.
325 class StateTupleEqual {
327 bool operator()(const StateTuple *tuple1, const StateTuple *tuple2) const {
328 return *tuple1 == *tuple2;
332 // Hash function for StateTuples.
333 class StateTupleKey {
335 size_t operator()(const StateTuple *tuple) const {
336 size_t h = tuple->filter_state.Hash();
337 for (auto it = tuple->subset.begin(); it != tuple->subset.end(); ++it) {
338 const size_t h1 = it->state_id;
339 static constexpr auto lshift = 5;
340 static constexpr auto rshift = CHAR_BIT * sizeof(size_t) - 5;
341 h ^= h << 1 ^ h1 << lshift ^ h1 >> rshift ^ it->weight.Hash();
348 CompactHashBiTable<StateId, StateTuple *, StateTupleKey, StateTupleEqual,
352 DefaultDeterminizeStateTable &operator=(
353 const DefaultDeterminizeStateTable &) = delete;
356 // Determinization type.
357 enum DeterminizeType {
358 // Input transducer is known to be functional (or error).
359 DETERMINIZE_FUNCTIONAL, // Input transducer is functional (error if not).
360 // Input transducer is not known to be functional.
361 DETERMINIZE_NONFUNCTIONAL,
362 // Input transducer is not known to be functional, but only keep the min
363 // of ambiguous outputs.
364 DETERMINIZE_DISAMBIGUATE
367 // Options for finite-state transducer determinization templated on the arc
368 // type, common divisor, the determinization filter and the state table.
369 // DeterminizeFst takes ownership of the determinization filter and state table,
372 class CommonDivisor = DefaultCommonDivisor<typename Arc::Weight>,
373 class Filter = DefaultDeterminizeFilter<Arc>,
375 DefaultDeterminizeStateTable<Arc, typename Filter::FilterState>>
376 struct DeterminizeFstOptions : public CacheOptions {
377 using Label = typename Arc::Label;
379 float delta; // Quantization delta for subset weights.
380 Label subsequential_label; // Label used for residual final output
381 // when producing subsequential transducers.
382 DeterminizeType type; // Determinization type.
383 bool increment_subsequential_label; // When creating several subsequential
384 // arcs at a given state, make their
385 // label distinct by incrementing.
386 Filter *filter; // Determinization filter;
387 // DeterminizeFst takes ownership.
388 StateTable *state_table; // Determinization state table;
389 // DeterminizeFst takes ownership.
391 explicit DeterminizeFstOptions(const CacheOptions &opts, float delta = kDelta,
392 Label subsequential_label = 0,
393 DeterminizeType type = DETERMINIZE_FUNCTIONAL,
394 bool increment_subsequential_label = false,
395 Filter *filter = nullptr,
396 StateTable *state_table = nullptr)
397 : CacheOptions(opts),
399 subsequential_label(subsequential_label),
401 increment_subsequential_label(increment_subsequential_label),
403 state_table(state_table) {}
405 explicit DeterminizeFstOptions(float delta = kDelta,
406 Label subsequential_label = 0,
407 DeterminizeType type = DETERMINIZE_FUNCTIONAL,
408 bool increment_subsequential_label = false,
409 Filter *filter = nullptr,
410 StateTable *state_table = nullptr)
412 subsequential_label(subsequential_label),
414 increment_subsequential_label(increment_subsequential_label),
416 state_table(state_table) {}
421 // Implementation of delayed DeterminizeFst. This base class is
422 // common to the variants that implement acceptor and transducer
425 class DeterminizeFstImplBase : public CacheImpl<Arc> {
427 using Label = typename Arc::Label;
428 using StateId = typename Arc::StateId;
429 using Weight = typename Arc::Weight;
431 using Store = DefaultCacheStore<Arc>;
432 using State = typename Store::State;
434 using FstImpl<Arc>::SetType;
435 using FstImpl<Arc>::SetProperties;
436 using FstImpl<Arc>::Properties;
437 using FstImpl<Arc>::SetInputSymbols;
438 using FstImpl<Arc>::SetOutputSymbols;
440 using CacheBaseImpl<CacheState<Arc>>::HasStart;
441 using CacheBaseImpl<CacheState<Arc>>::HasFinal;
442 using CacheBaseImpl<CacheState<Arc>>::HasArcs;
443 using CacheBaseImpl<CacheState<Arc>>::SetFinal;
444 using CacheBaseImpl<CacheState<Arc>>::SetStart;
446 template <class CommonDivisor, class Filter, class StateTable>
447 DeterminizeFstImplBase(
449 const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
450 : CacheImpl<Arc>(opts), fst_(fst.Copy()) {
451 SetType("determinize");
452 const auto iprops = fst.Properties(kFstProperties, false);
454 DeterminizeProperties(iprops, opts.subsequential_label != 0,
455 opts.type == DETERMINIZE_NONFUNCTIONAL
456 ? opts.increment_subsequential_label
458 SetProperties(Filter::Properties(dprops), kCopyProperties);
459 SetInputSymbols(fst.InputSymbols());
460 SetOutputSymbols(fst.OutputSymbols());
463 DeterminizeFstImplBase(const DeterminizeFstImplBase<Arc> &impl)
464 : CacheImpl<Arc>(impl), fst_(impl.fst_->Copy(true)) {
465 SetType("determinize");
466 SetProperties(impl.Properties(), kCopyProperties);
467 SetInputSymbols(impl.InputSymbols());
468 SetOutputSymbols(impl.OutputSymbols());
471 virtual DeterminizeFstImplBase<Arc> *Copy() const = 0;
475 const auto start = ComputeStart();
476 if (start != kNoStateId) SetStart(start);
478 return CacheImpl<Arc>::Start();
481 Weight Final(StateId s) {
482 if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
483 return CacheImpl<Arc>::Final(s);
486 virtual void Expand(StateId s) = 0;
488 size_t NumArcs(StateId s) {
489 if (!HasArcs(s)) Expand(s);
490 return CacheImpl<Arc>::NumArcs(s);
493 size_t NumInputEpsilons(StateId s) {
494 if (!HasArcs(s)) Expand(s);
495 return CacheImpl<Arc>::NumInputEpsilons(s);
498 size_t NumOutputEpsilons(StateId s) {
499 if (!HasArcs(s)) Expand(s);
500 return CacheImpl<Arc>::NumOutputEpsilons(s);
503 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
504 if (!HasArcs(s)) Expand(s);
505 CacheImpl<Arc>::InitArcIterator(s, data);
508 virtual StateId ComputeStart() = 0;
510 virtual Weight ComputeFinal(StateId s) = 0;
512 const Fst<Arc> &GetFst() const { return *fst_; }
515 std::unique_ptr<const Fst<Arc>> fst_; // Input FST.
518 // Implementation of delayed determinization for weighted acceptors.
519 template <class Arc, class CommonDivisor, class Filter, class StateTable>
520 class DeterminizeFsaImpl : public DeterminizeFstImplBase<Arc> {
522 using Label = typename Arc::Label;
523 using StateId = typename Arc::StateId;
524 using Weight = typename Arc::Weight;
526 using FilterState = typename Filter::FilterState;
527 using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
528 using Element = typename StateTuple::Element;
529 using Subset = typename StateTuple::Subset;
530 using LabelMap = typename Filter::LabelMap;
532 using FstImpl<Arc>::SetProperties;
533 using DeterminizeFstImplBase<Arc>::GetFst;
534 using DeterminizeFstImplBase<Arc>::SetArcs;
537 const Fst<Arc> &fst, const std::vector<Weight> *in_dist,
538 std::vector<Weight> *out_dist,
539 const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
540 : DeterminizeFstImplBase<Arc>(fst, opts),
544 filter_(opts.filter ? opts.filter : new Filter(fst)),
545 state_table_(opts.state_table ? opts.state_table : new StateTable()) {
546 if (!fst.Properties(kAcceptor, true)) {
547 FSTERROR() << "DeterminizeFst: Argument not an acceptor";
548 SetProperties(kError, kError);
550 if (!(Weight::Properties() & kLeftSemiring)) {
551 FSTERROR() << "DeterminizeFst: Weight must be left distributive: "
553 SetProperties(kError, kError);
555 if (out_dist_) out_dist_->clear();
559 const DeterminizeFsaImpl<Arc, CommonDivisor, Filter, StateTable> &impl)
560 : DeterminizeFstImplBase<Arc>(impl),
564 filter_(new Filter(*impl.filter_, &GetFst())),
565 state_table_(new StateTable(*impl.state_table_)) {
566 if (impl.out_dist_) {
567 FSTERROR() << "DeterminizeFsaImpl: Cannot copy with out_dist vector";
568 SetProperties(kError, kError);
572 DeterminizeFsaImpl<Arc, CommonDivisor, Filter, StateTable> *Copy()
574 return new DeterminizeFsaImpl<Arc, CommonDivisor, Filter, StateTable>(
578 uint64 Properties() const override { return Properties(kFstProperties); }
580 // Sets error if found, and returns other FST impl properties.
581 uint64 Properties(uint64 mask) const override {
582 if ((mask & kError) && (GetFst().Properties(kError, false))) {
583 SetProperties(kError, kError);
585 return FstImpl<Arc>::Properties(mask);
588 StateId ComputeStart() override {
589 const auto s = GetFst().Start();
590 if (s == kNoStateId) return kNoStateId;
591 const Element element(s, Weight::One());
592 auto *tuple = new StateTuple;
593 tuple->subset.push_front(element);
594 tuple->filter_state = filter_->Start();
595 return FindState(tuple);
598 Weight ComputeFinal(StateId s) override {
599 const auto *tuple = state_table_->Tuple(s);
600 filter_->SetState(s, *tuple);
601 auto final_weight = Weight::Zero();
602 for (auto it = tuple->subset.begin(); it != tuple->subset.end(); ++it) {
603 const auto &element = *it;
606 Times(element.weight, GetFst().Final(element.state_id)));
607 final_weight = filter_->FilterFinal(final_weight, element);
608 if (!final_weight.Member()) SetProperties(kError, kError);
613 StateId FindState(StateTuple *tuple) {
614 const auto s = state_table_->FindState(tuple);
615 if (in_dist_ && out_dist_->size() <= s) {
616 out_dist_->push_back(ComputeDistance(tuple->subset));
621 // Computes distance from a state to the final states in the DFA given the
622 // distances in the NFA.
623 Weight ComputeDistance(const Subset &subset) {
624 auto outd = Weight::Zero();
625 for (auto it = subset.begin(); it != subset.end(); ++it) {
626 const auto &element = *it;
628 (element.state_id < in_dist_->size() ? (*in_dist_)[element.state_id]
630 outd = Plus(outd, Times(element.weight, ind));
635 // Computes the outgoing transitions from a state, creating new destination
637 void Expand(StateId s) override {
639 GetLabelMap(s, &label_map);
640 for (auto it = label_map.begin(); it != label_map.end(); ++it) {
641 AddArc(s, it->second);
647 using DetArc = internal::DeterminizeArc<StateTuple>;
649 // Constructs proto-determinization transition, including destination subset,
651 void GetLabelMap(StateId s, LabelMap *label_map) {
652 const auto *src_tuple = state_table_->Tuple(s);
653 filter_->SetState(s, *src_tuple);
654 for (auto it = src_tuple->subset.begin(); it != src_tuple->subset.end();
656 const auto &src_element = *it;
657 for (ArcIterator<Fst<Arc>> aiter(GetFst(), src_element.state_id);
658 !aiter.Done(); aiter.Next()) {
659 const auto &arc = aiter.Value();
660 const Element dest_element(arc.nextstate,
661 Times(src_element.weight, arc.weight));
662 filter_->FilterArc(arc, src_element, dest_element, label_map);
665 for (auto it = label_map->begin(); it != label_map->end(); ++it) {
666 NormArc(&it->second);
670 // Sorts subsets and removes duplicate elements, normalizing transition and
672 void NormArc(DetArc *det_arc) {
673 auto *dest_tuple = det_arc->dest_tuple;
674 dest_tuple->subset.sort();
675 auto piter = dest_tuple->subset.begin();
676 for (auto diter = dest_tuple->subset.begin();
677 diter != dest_tuple->subset.end();) {
678 auto &dest_element = *diter;
679 auto &prev_element = *piter;
680 // Computes arc weight.
681 det_arc->weight = common_divisor_(det_arc->weight, dest_element.weight);
682 if (piter != diter && dest_element.state_id == prev_element.state_id) {
683 // Found duplicate state: sums state weight and deletes duplicate.
684 prev_element.weight = Plus(prev_element.weight, dest_element.weight);
685 if (!prev_element.weight.Member()) SetProperties(kError, kError);
687 dest_tuple->subset.erase_after(piter);
693 // Divides out label weight from destination subset elements, quantizing to
694 // ensure comparisons are effective.
695 for (auto diter = dest_tuple->subset.begin();
696 diter != dest_tuple->subset.end(); ++diter) {
697 auto &dest_element = *diter;
698 dest_element.weight =
699 Divide(dest_element.weight, det_arc->weight, DIVIDE_LEFT);
700 dest_element.weight = dest_element.weight.Quantize(delta_);
704 // Adds an arc from state S to the destination state associated with state
705 // tuple in det_arc as created by GetLabelMap.
706 void AddArc(StateId s, const DetArc &det_arc) {
707 const Arc arc(det_arc.label, det_arc.label, det_arc.weight,
708 FindState(det_arc.dest_tuple));
709 CacheImpl<Arc>::PushArc(s, arc);
712 float delta_; // Quantization delta for weights.
713 const std::vector<Weight> *in_dist_; // Distance to final NFA states.
714 std::vector<Weight> *out_dist_; // Distance to final DFA states.
716 // FIXME(kbg): Ought to be static const?
717 CommonDivisor common_divisor_;
718 std::unique_ptr<Filter> filter_;
719 std::unique_ptr<StateTable> state_table_;
722 // Implementation of delayed determinization for transducers. Transducer
723 // determinization is implemented by mapping the input to the Gallic semiring as
724 // an acceptor whose weights contain the output strings and using acceptor
725 // determinization above to determinize that acceptor.
726 template <class Arc, GallicType G, class CommonDivisor, class Filter,
728 class DeterminizeFstImpl : public DeterminizeFstImplBase<Arc> {
730 using Label = typename Arc::Label;
731 using StateId = typename Arc::StateId;
732 using Weight = typename Arc::Weight;
734 using ToMapper = ToGallicMapper<Arc, G>;
735 using ToArc = typename ToMapper::ToArc;
736 using ToFst = ArcMapFst<Arc, ToArc, ToMapper>;
737 using FromMapper = FromGallicMapper<Arc, G>;
738 using FromFst = ArcMapFst<ToArc, Arc, FromMapper>;
740 using ToCommonDivisor = GallicCommonDivisor<Label, Weight, G, CommonDivisor>;
741 using ToFilter = typename Filter::template rebind<ToArc>::Other;
742 using ToFilterState = typename ToFilter::FilterState;
744 typename StateTable::template rebind<ToArc, ToFilterState>::Other;
745 using FactorIterator = GallicFactor<Label, Weight, G>;
747 using FstImpl<Arc>::SetProperties;
748 using DeterminizeFstImplBase<Arc>::GetFst;
749 using CacheBaseImpl<CacheState<Arc>>::GetCacheGc;
750 using CacheBaseImpl<CacheState<Arc>>::GetCacheLimit;
754 const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
755 : DeterminizeFstImplBase<Arc>(fst, opts),
757 subsequential_label_(opts.subsequential_label),
758 increment_subsequential_label_(opts.increment_subsequential_label) {
759 if (opts.state_table) {
760 FSTERROR() << "DeterminizeFst: "
761 << "A state table can not be passed with transducer input";
762 SetProperties(kError, kError);
765 Init(GetFst(), opts.filter);
769 const DeterminizeFstImpl<Arc, G, CommonDivisor, Filter, StateTable> &impl)
770 : DeterminizeFstImplBase<Arc>(impl),
772 subsequential_label_(impl.subsequential_label_),
773 increment_subsequential_label_(impl.increment_subsequential_label_) {
774 Init(GetFst(), nullptr);
777 DeterminizeFstImpl<Arc, G, CommonDivisor, Filter, StateTable> *Copy()
779 return new DeterminizeFstImpl<Arc, G, CommonDivisor, Filter, StateTable>(
783 uint64 Properties() const override { return Properties(kFstProperties); }
785 // Sets error if found, and returns other FST impl properties.
786 uint64 Properties(uint64 mask) const override {
787 if ((mask & kError) && (GetFst().Properties(kError, false) ||
788 from_fst_->Properties(kError, false))) {
789 SetProperties(kError, kError);
791 return FstImpl<Arc>::Properties(mask);
794 StateId ComputeStart() override { return from_fst_->Start(); }
796 Weight ComputeFinal(StateId s) override { return from_fst_->Final(s); }
798 void Expand(StateId s) override {
799 for (ArcIterator<FromFst> aiter(*from_fst_, s); !aiter.Done();
801 CacheImpl<Arc>::PushArc(s, aiter.Value());
803 CacheImpl<Arc>::SetArcs(s);
807 // Initialization of transducer determinization implementation, which is
808 // defined after DeterminizeFst since it calls it.
809 void Init(const Fst<Arc> &fst, Filter *filter);
812 Label subsequential_label_;
813 bool increment_subsequential_label_;
814 std::unique_ptr<FromFst> from_fst_;
817 } // namespace internal
819 // Determinizes a weighted transducer. This version is a delayed
820 // FST. The result will be an equivalent FST that has the property
821 // that no state has two transitions with the same input label.
822 // For this algorithm, epsilon transitions are treated as regular
823 // symbols (cf. RmEpsilon).
825 // The transducer must be functional. The weights must be (weakly) left
826 // divisible (valid for TropicalWeight and LogWeight for instance) and be
827 // zero-sum-free, i.e., for all a, b: (Plus(a, b) == 0) => a = b = 0.
831 // Determinizable: exponential (polynomial in the size of the output).
832 // Non-determinizable: does not terminate.
834 // The determinizable automata include all unweighted and all acyclic input.
836 // For more information, see:
838 // Mohri, M. 1997. Finite-state transducers in language and speech processing.
839 // Computational Linguistics 23(2): 269-311.
841 // This class attaches interface to implementation and handles reference
842 // counting, delegating most methods to ImplToFst.
844 class DeterminizeFst : public ImplToFst<internal::DeterminizeFstImplBase<A>> {
847 using Label = typename Arc::Label;
848 using StateId = typename Arc::StateId;
849 using Weight = typename Arc::Weight;
851 using Store = DefaultCacheStore<Arc>;
852 using State = typename Store::State;
853 using Impl = internal::DeterminizeFstImplBase<Arc>;
855 friend class ArcIterator<DeterminizeFst<Arc>>;
856 friend class StateIterator<DeterminizeFst<Arc>>;
858 template <class B, GallicType G, class CommonDivisor, class Filter,
860 friend class DeterminizeFstImpl;
862 explicit DeterminizeFst(const Fst<A> &fst)
863 : ImplToFst<Impl>(CreateImpl(fst)) {}
865 template <class CommonDivisor, class Filter, class StateTable>
868 const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
870 DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>())
871 : ImplToFst<Impl>(CreateImpl(fst, opts)) {}
873 // This acceptor-only version additionally computes the distance to final
874 // states in the output if provided with those distances for the input; this
875 // is useful for e.g., computing the k-shortest unique paths.
876 template <class CommonDivisor, class Filter, class StateTable>
878 const Fst<Arc> &fst, const std::vector<Weight> *in_dist,
879 std::vector<Weight> *out_dist,
880 const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
882 DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>())
884 std::make_shared<internal::DeterminizeFsaImpl<Arc, CommonDivisor,
885 Filter, StateTable>>(
886 fst, in_dist, out_dist, opts)) {
887 if (!fst.Properties(kAcceptor, true)) {
888 FSTERROR() << "DeterminizeFst: "
889 << "Distance to final states computed for acceptors only";
890 GetMutableImpl()->SetProperties(kError, kError);
894 // See Fst<>::Copy() for doc.
895 DeterminizeFst(const DeterminizeFst<Arc> &fst, bool safe = false)
896 : ImplToFst<Impl>(safe ? std::shared_ptr<Impl>(fst.GetImpl()->Copy())
897 : fst.GetSharedImpl()) {}
899 // Get a copy of this DeterminizeFst. See Fst<>::Copy() for further doc.
900 DeterminizeFst<Arc> *Copy(bool safe = false) const override {
901 return new DeterminizeFst<Arc>(*this, safe);
904 inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
906 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
907 GetMutableImpl()->InitArcIterator(s, data);
911 using ImplToFst<Impl>::GetImpl;
912 using ImplToFst<Impl>::GetMutableImpl;
914 static std::shared_ptr<Impl> CreateImpl(const Fst<Arc> &fst) {
915 using D = DefaultCommonDivisor<Weight>;
916 using F = DefaultDeterminizeFilter<Arc>;
917 using T = DefaultDeterminizeStateTable<Arc, typename F::FilterState>;
918 const DeterminizeFstOptions<Arc, D, F, T> opts;
919 return CreateImpl(fst, opts);
922 template <class CommonDivisor, class Filter, class StateTable>
923 static std::shared_ptr<Impl> CreateImpl(
925 const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
927 if (fst.Properties(kAcceptor, true)) {
928 // Calls implementation for acceptors.
929 return std::make_shared<
930 internal::DeterminizeFsaImpl<Arc, CommonDivisor, Filter, StateTable>>(
931 fst, nullptr, nullptr, opts);
932 } else if (opts.type == DETERMINIZE_DISAMBIGUATE) {
933 auto rv = std::make_shared<internal::DeterminizeFstImpl<
934 Arc, GALLIC_MIN, CommonDivisor, Filter, StateTable>>(fst, opts);
935 if (!(Weight::Properties() & kPath)) {
936 FSTERROR() << "DeterminizeFst: Weight needs to have the "
937 << "path property to disambiguate output: "
939 rv->SetProperties(kError, kError);
941 // Calls disambiguating implementation for non-functional transducers.
943 } else if (opts.type == DETERMINIZE_FUNCTIONAL) {
944 // Calls implementation for functional transducers.
945 return std::make_shared<internal::DeterminizeFstImpl<
946 Arc, GALLIC_RESTRICT, CommonDivisor, Filter, StateTable>>(fst, opts);
947 } else { // opts.type == DETERMINIZE_NONFUNCTIONAL
948 // Calls implementation for non functional transducers;
949 return std::make_shared<internal::DeterminizeFstImpl<
950 Arc, GALLIC, CommonDivisor, Filter, StateTable>>(fst, opts);
954 DeterminizeFst &operator=(const DeterminizeFst &) = delete;
959 // Initialization of transducer determinization implementation, which is defined
960 // after DeterminizeFst since it calls it.
961 template <class A, GallicType G, class D, class F, class T>
962 void DeterminizeFstImpl<A, G, D, F, T>::Init(const Fst<A> &fst, F *filter) {
963 // Mapper to an acceptor.
964 const ToFst to_fst(fst, ToMapper());
965 auto *to_filter = filter ? new ToFilter(to_fst, filter) : nullptr;
966 // This recursive call terminates since it is to a (non-recursive)
967 // different constructor.
968 const CacheOptions copts(GetCacheGc(), GetCacheLimit());
969 const DeterminizeFstOptions<ToArc, ToCommonDivisor, ToFilter, ToStateTable>
970 dopts(copts, delta_, 0, DETERMINIZE_FUNCTIONAL, false, to_filter);
971 // Uses acceptor-only constructor to avoid template recursion.
972 const DeterminizeFst<ToArc> det_fsa(to_fst, nullptr, nullptr, dopts);
973 // Mapper back to transducer.
974 const FactorWeightOptions<ToArc> fopts(
975 CacheOptions(true, 0), delta_, kFactorFinalWeights, subsequential_label_,
976 subsequential_label_, increment_subsequential_label_,
977 increment_subsequential_label_);
978 const FactorWeightFst<ToArc, FactorIterator> factored_fst(det_fsa, fopts);
979 from_fst_.reset(new FromFst(factored_fst, FromMapper(subsequential_label_)));
982 } // namespace internal
984 // Specialization for DeterminizeFst.
986 class StateIterator<DeterminizeFst<Arc>>
987 : public CacheStateIterator<DeterminizeFst<Arc>> {
989 explicit StateIterator(const DeterminizeFst<Arc> &fst)
990 : CacheStateIterator<DeterminizeFst<Arc>>(fst, fst.GetMutableImpl()) {}
993 // Specialization for DeterminizeFst.
995 class ArcIterator<DeterminizeFst<Arc>>
996 : public CacheArcIterator<DeterminizeFst<Arc>> {
998 using StateId = typename Arc::StateId;
1000 ArcIterator(const DeterminizeFst<Arc> &fst, StateId s)
1001 : CacheArcIterator<DeterminizeFst<Arc>>(fst.GetMutableImpl(), s) {
1002 if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
1006 template <class Arc>
1007 inline void DeterminizeFst<Arc>::InitStateIterator(
1008 StateIteratorData<Arc> *data) const {
1009 data->base = new StateIterator<DeterminizeFst<Arc>>(*this);
1012 // Useful aliases when using StdArc.
1013 using StdDeterminizeFst = DeterminizeFst<StdArc>;
1015 template <class Arc>
1016 struct DeterminizeOptions {
1017 using Label = typename Arc::Label;
1018 using StateId = typename Arc::StateId;
1019 using Weight = typename Arc::Weight;
1021 float delta; // Quantization delta for subset weights.
1022 Weight weight_threshold; // Pruning weight threshold.
1023 StateId state_threshold; // Pruning state threshold.
1024 Label subsequential_label; // Label used for residual final output.
1025 DeterminizeType type;
1026 bool increment_subsequential_label; // When creating several subsequential
1027 // arcs at a given state, make their
1028 // label distinct by incrementation?
1030 explicit DeterminizeOptions(float delta = kDelta,
1031 Weight weight_threshold = Weight::Zero(),
1032 StateId state_threshold = kNoStateId,
1033 Label subsequential_label = 0,
1034 DeterminizeType type = DETERMINIZE_FUNCTIONAL,
1035 bool increment_subsequential_label = false)
1037 weight_threshold(std::move(weight_threshold)),
1038 state_threshold(state_threshold),
1039 subsequential_label(subsequential_label),
1041 increment_subsequential_label(increment_subsequential_label) {}
1044 // Determinizes a weighted transducer. This version writes the
1045 // determinized Fst to an output MutableFst. The result will be an
1046 // equivalent FST that has the property that no state has two
1047 // transitions with the same input label. For this algorithm, epsilon
1048 // transitions are treated as regular symbols (cf. RmEpsilon).
1050 // The transducer must be functional. The weights must be (weakly)
1051 // left divisible (valid for TropicalWeight and LogWeight).
1055 // Determinizable: exponential (polynomial in the size of the output)
1056 // Non-determinizable: does not terminate
1058 // The determinizable automata include all unweighted and all acyclic input.
1059 template <class Arc>
1061 const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
1062 const DeterminizeOptions<Arc> &opts = DeterminizeOptions<Arc>()) {
1063 using Weight = typename Arc::Weight;
1064 DeterminizeFstOptions<Arc> nopts;
1065 nopts.delta = opts.delta;
1066 nopts.subsequential_label = opts.subsequential_label;
1067 nopts.type = opts.type;
1068 nopts.increment_subsequential_label = opts.increment_subsequential_label;
1069 nopts.gc_limit = 0; // Caches only the last state for fastest copy.
1070 if (opts.weight_threshold != Weight::Zero() ||
1071 opts.state_threshold != kNoStateId) {
1072 if (ifst.Properties(kAcceptor, false)) {
1073 std::vector<Weight> idistance;
1074 std::vector<Weight> odistance;
1075 ShortestDistance(ifst, &idistance, true);
1076 DeterminizeFst<Arc> dfst(ifst, &idistance, &odistance, nopts);
1077 PruneOptions<Arc, AnyArcFilter<Arc>> popts(
1078 opts.weight_threshold, opts.state_threshold, AnyArcFilter<Arc>(),
1080 Prune(dfst, ofst, popts);
1082 *ofst = DeterminizeFst<Arc>(ifst, nopts);
1083 Prune(ofst, opts.weight_threshold, opts.state_threshold);
1086 *ofst = DeterminizeFst<Arc>(ifst, nopts);
1092 #endif // FST_DETERMINIZE_H_