1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions and classes to minimize an FST.
6 #ifndef FST_MINIMIZE_H_
7 #define FST_MINIMIZE_H_
19 #include <fst/arcsort.h>
20 #include <fst/connect.h>
21 #include <fst/dfs-visit.h>
22 #include <fst/encode.h>
23 #include <fst/factor-weight.h>
25 #include <fst/mutable-fst.h>
26 #include <fst/partition.h>
28 #include <fst/queue.h>
29 #include <fst/reverse.h>
30 #include <fst/state-map.h>
// StateComparator<Arc>: a strict weak ordering on state IDs of `fst`, used as
// the key comparator of the std::map built in AcyclicMinimizer::Refine().
// States are ordered by (1) the Hash() of their final weights, (2) their
// number of arcs, and (3) arc by arc, the ilabel and then the partition class
// of the destination state.
// NOTE(review): this excerpt omits some original lines (the embedded line
// numbers skip), including the template header, several return statements
// and closing braces; the comments below describe only the visible code.
36 // Comparator for creating partition.
38 class StateComparator {
40 using StateId = typename Arc::StateId;
41 using Weight = typename Arc::Weight;
// `partition` is bound by reference (see partition_ below), so it must outlive
// this comparator; Refine() builds one comparator and keeps using it while it
// moves states between classes, so lookups always see the current classes.
43 StateComparator(const Fst<Arc> &fst, const Partition<StateId> &partition)
44 : fst_(fst), partition_(partition) {}
46 // Compares state x with state y based on sort criteria.
47 bool operator()(const StateId x, const StateId y) const {
48 // Checks for final state equivalence.
// NOTE(review): final weights are compared only through Hash(), so two
// distinct final weights whose hashes collide are not told apart by this test.
49 const auto xfinal = fst_.Final(x).Hash();
50 const auto yfinal = fst_.Final(y).Hash();
51 if (xfinal < yfinal) {
53 } else if (xfinal > yfinal) {
56 // Checks for number of arcs.
57 if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
58 if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
59 // If the number of arcs is equal, checks for arc match.
// Arcs are walked pairwise in iteration order, so two states with the same arc
// set compare equal only if their arc lists are in the same order;
// AcceptorMinimize() ilabel-sorts the FST before building the AcyclicMinimizer.
60 for (ArcIterator<Fst<Arc>> aiter1(fst_, x), aiter2(fst_, y);
61 !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
62 const auto &arc1 = aiter1.Value();
63 const auto &arc2 = aiter2.Value();
64 if (arc1.ilabel < arc2.ilabel) return true;
65 if (arc1.ilabel > arc2.ilabel) return false;
// Ties on ilabel are broken by the class (not the ID) of the destination, so
// states whose successors are already known to be equivalent compare equal.
66 if (partition_.ClassId(arc1.nextstate) <
67 partition_.ClassId(arc2.nextstate))
69 if (partition_.ClassId(arc1.nextstate) >
70 partition_.ClassId(arc2.nextstate))
78 const Partition<StateId> &partition_;
81 // Computes equivalence classes for cyclic unweighted acceptors. For cyclic
82 // minimization we use the classic Hopcroft minimization algorithm, which has
83 // complexity O(E log V) where E is the number of arcs and V is the number of
// states.
86 // For more information, see:
88 // Hopcroft, J. 1971. An n Log n algorithm for minimizing states in a finite
89 // automaton. Ms, Stanford University.
91 // Note: the original presentation of the paper was for a finite automaton (==
92 // deterministic, unweighted acceptor), but we also apply it to the
93 // nondeterministic case, where it is also applicable as long as the semiring is
94 // idempotent (if the semiring is not idempotent, there are some complexities
95 // in keeping track of the weight when there are multiple arcs to states that
96 // will be merged, and we don't deal with this).
//
// Template parameters:
//   Arc   - arc type of the input acceptor.
//   Queue - queue of class IDs holding the active (splitter) classes (it is
//           used through Enqueue(), Head() and Empty()); AcceptorMinimize()
//           instantiates it with LifoQueue<StateId>.
97 template <class Arc, class Queue>
98 class CyclicMinimizer {
100 using Label = typename Arc::Label;
101 using StateId = typename Arc::StateId;
102 using ClassId = typename Arc::StateId;
103 using Weight = typename Arc::Weight;
// Arc type of Tr_, the reversed input FST used to enumerate predecessors.
104 using RevArc = ReverseArc<Arc>;
// NOTE(review): the constructor body is elided in this excerpt; presumably it
// runs Initialize(fst) followed by Compute(fst) — confirm against the full file.
106 explicit CyclicMinimizer(const ExpandedFst<Arc> &fst) {
// Returns the computed partition: each class is a set of equivalent states.
111 const Partition<StateId> &GetPartition() const { return P_; }
114 // StateILabelHasher is a hashing object that computes a hash-function
115 // of an FST state that depends only on the set of ilabels on arcs leaving
116 // the state [note: it assumes that the arcs are ilabel-sorted].
117 // In order to work correctly for non-deterministic automata, multiple
118 // instances of the same ilabel count the same as a single instance.
// The hash is a polynomial rolling hash over the ilabel sequence with adjacent
// repeats collapsed, so it is order-dependent.
// NOTE(review): in the visible AcceptorMinimize(), the FST is ilabel-sorted
// only on the acyclic path, not before CyclicMinimizer is constructed. If arcs
// can arrive unsorted here, two equivalent states could hash differently and be
// placed in different initial classes by PrePartition() — verify.
119 class StateILabelHasher {
121 explicit StateILabelHasher(const Fst<Arc> &fst) : fst_(fst) {}
123 using Label = typename Arc::Label;
124 using StateId = typename Arc::StateId;
// Returns a hash of the (deduplicated) ilabels leaving state s.
126 size_t operator()(const StateId s) {
// Multiplier and seed of the rolling hash. The declaration of `result`
// (presumably seeded with p2) is elided in this excerpt.
127 const size_t p1 = 7603;
128 const size_t p2 = 433024223;
// NOTE(review): current_ilabel is a size_t but is compared with a Label below.
// The comparison converts this_ilabel to size_t, the same way kNoLabel was
// converted here, so it behaves correctly; declaring it as Label would avoid
// the signed/unsigned mismatch.
130 size_t current_ilabel = kNoLabel;
131 for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
132 Label this_ilabel = aiter.Value().ilabel;
133 if (this_ilabel != current_ilabel) { // Ignores repeats.
134 result = p1 * result + this_ilabel;
135 current_ilabel = this_ilabel;
142 const Fst<Arc> &fst_;
// Orders arc iterators over Tr_ by the ilabel of their current arc. It returns
// `x > y`, so std::priority_queue (a max-heap) keeps the iterator with the
// smallest current ilabel at top(). Split() therefore visits labels in
// increasing order.
145 class ArcIterCompare {
147 explicit ArcIterCompare(const Partition<StateId> &partition)
148 : partition_(partition) {}
// NOTE(review): equivalent to the implicitly generated copy constructor.
150 ArcIterCompare(const ArcIterCompare &comp) : partition_(comp.partition_) {}
152 // Compares two iterators based on their input labels.
// Neither iterator may be Done(); Split() only pushes iterators that are not.
153 bool operator()(const ArcIterator<Fst<RevArc>> *x,
154 const ArcIterator<Fst<RevArc>> *y) const {
155 const auto &xarc = x->Value();
156 const auto &yarc = y->Value();
157 return xarc.ilabel > yarc.ilabel;
// NOTE(review): partition_ is not read by the visible operator().
161 const Partition<StateId> &partition_;
// Min-heap (by current ilabel) of heap-allocated arc iterators. Split() owns
// the raw pointers while they are in the queue.
165 std::priority_queue<ArcIterator<Fst<RevArc>> *,
166 std::vector<ArcIterator<Fst<RevArc>> *>,
170 // Prepartitions the space into equivalence classes. We ensure that final and
171 // non-final states always go into different equivalence classes, and we use
172 // class StateILabelHasher to make sure that most of the time, states with
173 // different sets of ilabels on arcs leaving them, go to different classes.
174 // Note: for the complexity guarantees we don't rely on the goodness of this
175 // hashing function---it just provides a bonus speedup.
176 void PrePartition(const ExpandedFst<Arc> &fst) {
177 VLOG(5) << "PrePartition";
178 StateId next_class = 0;
179 auto num_states = fst.NumStates();
180 // Allocates a temporary vector to store the initial class mappings, so that
181 // we can allocate the classes all at once.
182 std::vector<StateId> state_to_initial_class(num_states);
184 // We maintain two maps from hash-value to class---one for final states
185 // (final-prob == One()) and one for non-final states
186 // (final-prob == Zero()). We are processing unweighted acceptors, so these
187 // are the only two possible values.
188 using HashToClassMap = std::unordered_map<size_t, StateId>;
189 HashToClassMap hash_to_class_nonfinal;
190 HashToClassMap hash_to_class_final;
191 StateILabelHasher hasher(fst);
// Each state joins the class already assigned to its (finality, hash) pair,
// or opens a new class with the next free ID.
192 for (StateId s = 0; s < num_states; ++s) {
193 size_t hash = hasher(s);
194 HashToClassMap &this_map =
195 (fst.Final(s) != Weight::Zero() ? hash_to_class_final
196 : hash_to_class_nonfinal);
197 // Avoids two map lookups by using 'insert' instead of 'find'.
198 auto p = this_map.insert(std::make_pair(hash, next_class));
199 state_to_initial_class[s] = p.second ? next_class++ : p.first->second;
201 // Lets the unordered_maps go out of scope before we allocate the classes,
202 // to reduce the maximum amount of memory used.
204 P_.AllocateClasses(next_class);
205 for (StateId s = 0; s < num_states; ++s) {
206 P_.Add(s, state_to_initial_class[s]);
// Every initial class starts out as an active splitter.
208 for (StateId c = 0; c < next_class; ++c) L_.Enqueue(c);
209 VLOG(5) << "Initial Partition: " << P_.NumClasses();
212 // Creates inverse transition Tr_ = rev(fst), loops over states in FST and
213 // splits on final, creating two blocks in the partition corresponding to
// final and non-final states (PrePartition() may further split these blocks
// by ilabel hash).
215 void Initialize(const ExpandedFst<Arc> &fst) {
// NOTE(review): the lines that build Tr_ (original lines 216-217) and that
// prepare the initial partition (original line 224, presumably a
// PrePartition(fst) call) are elided in this excerpt.
// Each Tr_ state's arcs must be ilabel-sorted so that the heap in Split()
// merges them in label order.
218 ILabelCompare<RevArc> ilabel_comp;
219 ArcSort(&Tr_, ilabel_comp);
220 // Tells the partition how many elements to allocate. The first state in
221 // Tr_ is super-final state.
// Hence input state s corresponds to Tr_ state s + 1 throughout this class.
222 P_.Initialize(Tr_.NumStates() - 1);
223 // Prepares initial partition.
225 // Allocates arc iterator queue.
226 ArcIterCompare comp(P_);
227 aiter_queue_.reset(new ArcIterQueue(comp));
229 // Partitions all classes with destination C.
// The predecessors of C are visited one input label at a time, in increasing
// label order. Each predecessor in a non-singleton class is marked with
// P_.SplitOn(). P_.FinalizeSplit(&L_) runs whenever the label changes and once
// at the end; it presumably performs the marked splits and enqueues new
// splitter classes on L_.
230 void Split(ClassId C) {
231 // Prepares priority queue: opens arc iterator for each state in C, and
232 // inserts into priority queue.
233 for (PartitionIterator<StateId> siter(P_, C); !siter.Done(); siter.Next()) {
234 StateId s = siter.Value();
// Tr_ state s + 1 is input state s; only states with predecessors are queued.
235 if (Tr_.NumArcs(s + 1)) {
// The raw pointer is owned by aiter_queue_ until it is taken off the queue.
236 aiter_queue_->push(new ArcIterator<Fst<RevArc>>(Tr_, s + 1));
239 // Now pops arc iterator from queue, splits entering equivalence class, and
240 // re-inserts updated iterator into queue.
// -1 serves as a "no label seen yet" sentinel.
241 Label prev_label = -1;
242 while (!aiter_queue_->empty()) {
// Takes ownership of the top iterator. It is deleted at the end of this
// iteration unless released back into the queue below. (The pop() of the top
// element, original line 244, is elided in this excerpt.)
243 std::unique_ptr<ArcIterator<Fst<RevArc>>> aiter(aiter_queue_->top());
245 if (aiter->Done()) continue;
246 const auto &arc = aiter->Value();
// In Tr_, nextstate is the source of the original arc, offset by one.
// NOTE(review): `arc` could be reused here instead of calling Value() again.
247 auto from_state = aiter->Value().nextstate - 1;
248 auto from_label = arc.ilabel;
// Flushes the marks accumulated for the previous label before starting a new one.
249 if (prev_label != from_label) P_.FinalizeSplit(&L_);
250 auto from_class = P_.ClassId(from_state);
// A singleton class cannot be split further.
251 if (P_.ClassSize(from_class) > 1) P_.SplitOn(from_state);
252 prev_label = from_label;
// (Advancing the iterator, original line 253, is elided in this excerpt.)
254 if (!aiter->Done()) aiter_queue_->push(aiter.release());
256 P_.FinalizeSplit(&L_);
259 // Main loop for Hopcroft minimization.
// Repeatedly takes an active class from L_ and splits the classes of its
// predecessors on it, until no active class remains.
// NOTE(review): `fst` is not used in the visible lines of this method.
260 void Compute(const Fst<Arc> &fst) {
261 // Processes active classes (FIFO or LIFO, depending on the Queue type).
262 while (!L_.Empty()) {
// (Removing C from L_, original line 264, is elided in this excerpt.)
263 const auto C = L_.Head();
265 Split(C); // Splits on C, all labels in C.
270 // Partitioning of states into equivalence classes.
271 Partition<StateId> P_;
272 // Set of active classes to be processed in partition P.
// (The declaration of L_, presumably of type Queue, is elided in this excerpt.)
274 // Reversed transition function: Tr_ = rev(fst); input state s is Tr_ state s + 1.
275 VectorFst<RevArc> Tr_;
276 // Priority queue of open arc iterators for all states in the splitter
277 // equivalence class.
278 std::unique_ptr<ArcIterQueue> aiter_queue_;
281 // Computes equivalence classes for acyclic FST.
// Uses the Revuz algorithm, whose running time is O(E),
287 // where E is the number of arcs.
// Precondition: the acceptor is acyclic, deterministic and ilabel-sorted.
// AcceptorMinimize() takes this path only when acyclic minimization is
// allowed, and it arc-sorts the FST first.
289 // For more information, see:
291 // Revuz, D. 1992. Minimization of acyclic deterministic automata in linear
292 // time. Theoretical Computer Science 92(1): 181-189.
294 class AcyclicMinimizer {
296 using Label = typename Arc::Label;
297 using StateId = typename Arc::StateId;
298 using ClassId = typename Arc::StateId;
299 using Weight = typename Arc::Weight;
// NOTE(review): the constructor body is elided in this excerpt; presumably it
// runs Initialize(fst) followed by Refine(fst) — confirm.
301 explicit AcyclicMinimizer(const ExpandedFst<Arc> &fst) {
// Returns the computed partition: each class is a set of equivalent states.
// NOTE(review): unlike CyclicMinimizer::GetPartition(), this is not const.
306 const Partition<StateId> &GetPartition() { return partition_; }
309 // DFS visitor to compute the height (distance) to final state.
// More precisely, height(s) is the length of the longest path from s to a
// state with no outgoing arcs. Such a state has height 0, and every other
// state has 1 + the maximum height of its successors. In a trimmed FST, these
// sink states are final.
310 class HeightVisitor {
312 HeightVisitor() : max_height_(0), num_states_(0) {}
314 // Invoked before DFS visit.
315 void InitVisit(const Fst<Arc> &fst) {}
317 // Invoked when state is discovered (2nd arg is DFS tree root).
318 bool InitState(StateId s, StateId root) {
319 // Extends the height array as needed, marking new entries -1 (not yet known).
320 for (StateId i = height_.size(); i <= s; ++i) height_.push_back(-1);
321 if (s >= num_states_) num_states_ = s + 1;
325 // Invoked when tree arc examined (to undiscovered state).
326 bool TreeArc(StateId s, const Arc &arc) { return true; }
328 // Invoked when back arc examined (to unfinished state).
// This visitor is only used on acyclic FSTs, so no back arcs are expected.
329 bool BackArc(StateId s, const Arc &arc) { return true; }
331 // Invoked when forward or cross arc examined (to finished state).
// The destination's height is already final here, so propagate it to s.
332 bool ForwardOrCrossArc(StateId s, const Arc &arc) {
333 if (height_[arc.nextstate] + 1 > height_[s]) {
334 height_[s] = height_[arc.nextstate] + 1;
339 // Invoked when state finished (parent is kNoStateId for tree root).
// A state that is still -1 had no successor contribute a height, so it is a
// sink of height 0. Its height + 1 is then pushed up to the DFS parent. (A
// guard for parent == kNoStateId, original line 343, is elided here.)
340 void FinishState(StateId s, StateId parent, const Arc *parent_arc) {
341 if (height_[s] == -1) height_[s] = 0;
342 const auto h = height_[s] + 1;
344 if (h > height_[parent]) height_[parent] = h;
345 if (h > max_height_) max_height_ = h;
349 // Invoked after DFS visit.
350 void FinishVisit() {}
352 size_t max_height() const { return max_height_; }
354 const std::vector<StateId> &height() const { return height_; }
356 size_t num_states() const { return num_states_; }
359 std::vector<StateId> height_;
365 // Clusters states according to height (distance to final state).
// Class IDs equal heights: class h initially holds every state of height h.
366 void Initialize(const Fst<Arc> &fst) {
367 // Computes height (distance to final state).
368 HeightVisitor hvisitor;
369 DfsVisit(fst, &hvisitor);
370 // Creates initial partition based on height.
371 partition_.Initialize(hvisitor.num_states());
372 partition_.AllocateClasses(hvisitor.max_height() + 1);
373 const auto &hstates = hvisitor.height();
// NOTE(review): compares a StateId with hstates.size() (a size_t).
374 for (StateId s = 0; s < hstates.size(); ++s) partition_.Add(s, hstates[s]);
377 // Refines states based on arc sort (out degree, arc equivalence).
// Every arc goes from a state to one of strictly smaller height (see
// HeightVisitor). Processing heights in increasing order therefore means the
// destination classes read by StateComparator are already final when a level
// is refined.
378 void Refine(const Fst<Arc> &fst) {
379 using EquivalenceMap = std::map<StateId, StateId, StateComparator<Arc>>;
380 StateComparator<Arc> comp(fst, partition_);
381 // Starts with tail (height = 0).
// `height` is fixed before refining. Classes created by AddClass() below get
// larger IDs and are not revisited by this loop.
382 auto height = partition_.NumClasses();
383 for (StateId h = 0; h < height; ++h) {
384 EquivalenceMap equiv_classes(comp);
385 // Sorts states within equivalence class.
// The first state of class h keeps ID h; this presumes the class is non-empty.
// Each later state that is not equivalent (under comp) to an existing key gets
// a fresh class from AddClass(); an equivalent state keeps the existing key's
// class.
386 PartitionIterator<StateId> siter(partition_, h);
387 equiv_classes[siter.Value()] = h;
388 for (siter.Next(); !siter.Done(); siter.Next()) {
390 equiv_classes.insert(std::make_pair(siter.Value(), kNoStateId));
391 if (insert_result.second) {
392 insert_result.first->second = partition_.AddClass();
395 // Creates refined partition.
// Every state of class h compares equal to some key inserted above, so
// equiv_classes[s] finds an existing entry rather than inserting one.
396 for (siter.Reset(); !siter.Done();) {
397 const auto s = siter.Value();
398 const auto old_class = partition_.ClassId(s);
399 const auto new_class = equiv_classes[s];
400 // A move operation can invalidate the iterator, so we first update
401 // the iterator to the next element before we move the current element
// out of its class.
404 if (old_class != new_class) partition_.Move(s, new_class);
410 Partition<StateId> partition_;
413 // Given a partition and a Mutable FST, merges states of Fst in place (i.e.,
414 // destructively). Merging works by taking the first state in a class of the
415 // partition to be the representative state for the class. Each arc is then
416 // reconnected to this state. All states in the class are merged by adding
417 // their arcs to the representative state.
//
// The representative is the first state yielded by its class's
// PartitionIterator, so the inner loop visits it first. Arcs later appended to
// it from other members already point at representatives and are not visited
// again.
// After this, non-representative states have no incoming arcs and are not the
// start state. Removing them is presumably done by lines elided from this
// excerpt (e.g. a Connect() call) — confirm.
// Arcs copied from several members can duplicate each other; AcceptorMinimize()
// follows this call with ArcUniqueMapper.
419 void MergeStates(const Partition<typename Arc::StateId> &partition,
420 MutableFst<Arc> *fst) {
421 using StateId = typename Arc::StateId;
// state_map[c] is the representative state of class c.
422 std::vector<StateId> state_map(partition.NumClasses());
423 for (StateId i = 0; i < partition.NumClasses(); ++i) {
424 PartitionIterator<StateId> siter(partition, i);
425 state_map[i] = siter.Value(); // First state in partition.
427 // Relabels destination states.
428 for (StateId c = 0; c < partition.NumClasses(); ++c) {
429 for (PartitionIterator<StateId> siter(partition, c); !siter.Done();
431 const auto s = siter.Value();
432 for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
434 auto arc = aiter.Value();
435 arc.nextstate = state_map[partition.ClassId(arc.nextstate)];
436 if (s == state_map[c]) { // For the first state, just sets destination.
// For any other member, the remapped arc is copied onto the representative.
439 fst->AddArc(state_map[c], arc);
444 fst->SetStart(state_map[partition.ClassId(fst->Start())]);
// Minimizes `fst`, which must be an unweighted acceptor, in place. It uses the
// Revuz algorithm (AcyclicMinimizer) when allow_acyclic_minimization is true
// and the FST is acyclic, and Hopcroft's algorithm (CyclicMinimizer)
// otherwise. The resulting partition is applied with MergeStates().
// Any other input is reported with FSTERROR() and marked with kError.
// NOTE(review): the early return after the error (original line 455) and the
// connect step (original line 458) are elided in this excerpt.
449 void AcceptorMinimize(MutableFst<Arc> *fst,
450 bool allow_acyclic_minimization = true) {
451 if (!(fst->Properties(kAcceptor | kUnweighted, true) ==
452 (kAcceptor | kUnweighted))) {
453 FSTERROR() << "FST is not an unweighted acceptor";
454 fst->SetProperties(kError, kError);
457 // Connects FST before minimization, handles disconnected states.
459 if (fst->NumStates() == 0) return;
460 if (allow_acyclic_minimization && fst->Properties(kAcyclic, true)) {
461 // Acyclic minimization (Revuz).
462 VLOG(2) << "Acyclic minimization";
// StateComparator compares arcs pairwise in order, so arcs must be sorted.
463 ArcSort(fst, ILabelCompare<Arc>());
464 AcyclicMinimizer<Arc> minimizer(*fst);
465 MergeStates(minimizer.GetPartition(), fst);
467 // Either the FST has cycles, or it's generated from non-deterministic input
468 // (which the Revuz algorithm can't handle), so use the cyclic minimization
469 // algorithm of Hopcroft.
// NOTE(review): unlike the branch above, no ArcSort() is visible here, although
// CyclicMinimizer's StateILabelHasher documents that it assumes ilabel-sorted
// arcs — verify whether the input must be sorted before this point.
470 VLOG(2) << "Cyclic minimization";
471 CyclicMinimizer<Arc, LifoQueue<typename Arc::StateId>> minimizer(*fst);
472 MergeStates(minimizer.GetPartition(), fst);
474 // Merges in appropriate semiring: MergeStates() can leave identical parallel
// arcs on representative states; ArcUniqueMapper presumably removes them.
475 ArcUniqueMapper<Arc> mapper(*fst);
476 StateMap(fst, mapper);
479 } // namespace internal
481 // In place minimization of deterministic weighted automata and transducers,
482 // and also non-deterministic ones if they use an idempotent semiring.
483 // For transducers, if the 'sfst' argument is not null, the algorithm
484 // produces a compact factorization of the minimal transducer.
486 // In the acyclic deterministic case, we use an algorithm from Revuz that is
487 // linear in the number of arcs (edges) in the machine.
489 // In the cyclic or non-deterministic case, we use the classical Hopcroft
490 // minimization (which was presented for the deterministic case but which
491 // also works for non-deterministic FSTs); this has complexity O(e log v).
//
// Parameters:
//   fst          - FST to minimize in place. It is given the kError property
//                  when its input cannot be minimized.
//   sfst         - optional, used only for transducers. When non-null it
//                  receives the second factor of the factorization. Its
//                  output symbols are set to fst's original output symbols,
//                  and fst's output symbol table is then set to sfst's input
//                  symbol table.
//   delta        - tolerance passed to Push() and QuantizeMapper.
//   allow_nondet - must be true to minimize an FST that is not
//                  input-deterministic; the semiring must then be idempotent.
494 void Minimize(MutableFst<Arc> *fst, MutableFst<Arc> *sfst = nullptr,
495 float delta = kDelta, bool allow_nondet = false) {
496 using Weight = typename Arc::Weight;
497 const auto props = fst->Properties(
498 kAcceptor | kIDeterministic | kWeighted | kUnweighted, true);
// Set on every path through the branch below.
499 bool allow_acyclic_minimization;
500 if (props & kIDeterministic) {
501 allow_acyclic_minimization = true;
503 // Our approach to minimization of non-deterministic FSTs will only work in
504 // idempotent semirings---for non-deterministic inputs, a state could have
505 // multiple transitions to states that will get merged, and we'd have to
506 // sum their weights. The algorithm doesn't handle that.
// NOTE(review): the statement following each FSTERROR() below (original lines
// 511 and 516) is elided in this excerpt. It must return; otherwise
// minimization would continue on an FST already marked kError.
507 if (!(Weight::Properties() & kIdempotent)) {
508 fst->SetProperties(kError, kError);
509 FSTERROR() << "Cannot minimize a non-deterministic FST over a "
510 "non-idempotent semiring";
512 } else if (!allow_nondet) {
513 fst->SetProperties(kError, kError);
514 FSTERROR() << "Refusing to minimize a non-deterministic FST with "
515 << "allow_nondet = false";
518 // The Revuz algorithm won't work for nondeterministic inputs, so if the
519 // input is nondeterministic, we'll have to pass a bool saying not to use
520 // that algorithm. We check at this level rather than in AcceptorMinimize(),
521 // because it's possible that the FST at this level could be deterministic,
522 // but a harmless type of non-determinism could be introduced by Encode()
523 // (thanks to kEncodeWeights, if the FST has epsilons and has a final
524 // weight with weights equal to some epsilon arc.)
525 allow_acyclic_minimization = false;
// Transducer path: map to GallicArc<Arc, GALLIC_LEFT>, mark the result as an
// acceptor, push weights toward the initial state and quantize them. Labels
// and weights are then encoded so AcceptorMinimize() sees an unweighted
// acceptor, and the result is decoded afterwards.
527 if (!(props & kAcceptor)) { // Weighted transducer.
528 VectorFst<GallicArc<Arc, GALLIC_LEFT>> gfst;
529 ArcMap(*fst, &gfst, ToGallicMapper<Arc, GALLIC_LEFT>());
531 gfst.SetProperties(kAcceptor, kAcceptor);
532 Push(&gfst, REWEIGHT_TO_INITIAL, delta);
533 ArcMap(&gfst, QuantizeMapper<GallicArc<Arc, GALLIC_LEFT>>(delta));
534 EncodeMapper<GallicArc<Arc, GALLIC_LEFT>> encoder(
535 kEncodeLabels | kEncodeWeights, ENCODE);
536 Encode(&gfst, &encoder);
537 internal::AcceptorMinimize(&gfst, allow_acyclic_minimization);
538 Decode(&gfst, encoder);
// Without sfst (the branch condition, original line 539, is elided): the
// Gallic weights are factored with GallicFactor and mapped back into `fst`,
// with fst's output symbol table saved and restored around the ArcMap.
540 FactorWeightFst<GallicArc<Arc, GALLIC_LEFT>,
541 GallicFactor<typename Arc::Label, Weight, GALLIC_LEFT>>
543 std::unique_ptr<SymbolTable> osyms(
544 fst->OutputSymbols() ? fst->OutputSymbols()->Copy() : nullptr);
545 ArcMap(fwfst, fst, FromGallicMapper<Arc, GALLIC_LEFT>());
546 fst->SetOutputSymbols(osyms.get());
// With sfst: gfst is mapped into `fst` through GallicToNewSymbolsMapper, which
// is constructed over sfst, producing the two-factor form.
548 sfst->SetOutputSymbols(fst->OutputSymbols());
549 GallicToNewSymbolsMapper<Arc, GALLIC_LEFT> mapper(sfst);
550 ArcMap(gfst, fst, &mapper);
551 fst->SetOutputSymbols(sfst->InputSymbols());
553 } else if (props & kWeighted) { // Weighted acceptor.
// Same push / quantize / encode / minimize / decode pipeline, applied directly.
554 Push(fst, REWEIGHT_TO_INITIAL, delta);
555 ArcMap(fst, QuantizeMapper<Arc>(delta));
556 EncodeMapper<Arc> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
557 Encode(fst, &encoder);
558 internal::AcceptorMinimize(fst, allow_acyclic_minimization);
559 Decode(fst, encoder);
560 } else { // Unweighted acceptor.
561 internal::AcceptorMinimize(fst, allow_acyclic_minimization);
567 #endif // FST_MINIMIZE_H_