d68c43a71c0ff6cdf45f267a91d5c1680999535d
[platform/upstream/openfst.git] / src / include / fst / minimize.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes to minimize an FST.
5
6 #ifndef FST_MINIMIZE_H_
7 #define FST_MINIMIZE_H_
8
9 #include <cmath>
10
11 #include <algorithm>
12 #include <map>
13 #include <queue>
14 #include <utility>
15 #include <vector>
16
17 #include <fst/log.h>
18
19 #include <fst/arcsort.h>
20 #include <fst/connect.h>
21 #include <fst/dfs-visit.h>
22 #include <fst/encode.h>
23 #include <fst/factor-weight.h>
24 #include <fst/fst.h>
25 #include <fst/mutable-fst.h>
26 #include <fst/partition.h>
27 #include <fst/push.h>
28 #include <fst/queue.h>
29 #include <fst/reverse.h>
30 #include <fst/state-map.h>
31
32
33 namespace fst {
34 namespace internal {
35
36 // Comparator for creating partition.
37 template <class Arc>
38 class StateComparator {
39  public:
40   using StateId = typename Arc::StateId;
41   using Weight = typename Arc::Weight;
42
43   StateComparator(const Fst<Arc> &fst, const Partition<StateId> &partition)
44       : fst_(fst), partition_(partition) {}
45
46   // Compares state x with state y based on sort criteria.
47   bool operator()(const StateId x, const StateId y) const {
48     // Checks for final state equivalence.
49     const auto xfinal = fst_.Final(x).Hash();
50     const auto yfinal = fst_.Final(y).Hash();
51     if (xfinal < yfinal) {
52       return true;
53     } else if (xfinal > yfinal) {
54       return false;
55     }
56     // Checks for number of arcs.
57     if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
58     if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
59     // If the number of arcs are equal, checks for arc match.
60     for (ArcIterator<Fst<Arc>> aiter1(fst_, x), aiter2(fst_, y);
61          !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
62       const auto &arc1 = aiter1.Value();
63       const auto &arc2 = aiter2.Value();
64       if (arc1.ilabel < arc2.ilabel) return true;
65       if (arc1.ilabel > arc2.ilabel) return false;
66       if (partition_.ClassId(arc1.nextstate) <
67           partition_.ClassId(arc2.nextstate))
68         return true;
69       if (partition_.ClassId(arc1.nextstate) >
70           partition_.ClassId(arc2.nextstate))
71         return false;
72     }
73     return false;
74   }
75
76  private:
77   const Fst<Arc> &fst_;
78   const Partition<StateId> &partition_;
79 };
80
81 // Computes equivalence classes for cyclic unweighted acceptors. For cyclic
82 // minimization we use the classic Hopcroft minimization algorithm, which has
83 // complexity O(E log V) where E is the number of arcs and V is the number of
84 // states.
85 //
86 // For more information, see:
87 //
88 //  Hopcroft, J. 1971. An n Log n algorithm for minimizing states in a finite
89 //  automaton. Ms, Stanford University.
90 //
91 // Note: the original presentation of the paper was for a finite automaton (==
92 // deterministic, unweighted acceptor), but we also apply it to the
93 // nondeterministic case, where it is also applicable as long as the semiring is
94 // idempotent (if the semiring is not idempotent, there are some complexities
95 // in keeping track of the weight when there are multiple arcs to states that
96 // will be merged, and we don't deal with this).
97 template <class Arc, class Queue>
98 class CyclicMinimizer {
99  public:
100   using Label = typename Arc::Label;
101   using StateId = typename Arc::StateId;
102   using ClassId = typename Arc::StateId;
103   using Weight = typename Arc::Weight;
104   using RevArc = ReverseArc<Arc>;
105
106   explicit CyclicMinimizer(const ExpandedFst<Arc> &fst) {
107     Initialize(fst);
108     Compute(fst);
109   }
110
111   const Partition<StateId> &GetPartition() const { return P_; }
112
113  private:
114   // StateILabelHasher is a hashing object that computes a hash-function
115   // of an FST state that depends only on the set of ilabels on arcs leaving
116   // the state [note: it assumes that the arcs are ilabel-sorted].
117   // In order to work correctly for non-deterministic automata, multiple
118   // instances of the same ilabel count the same as a single instance.
119   class StateILabelHasher {
120    public:
121     explicit StateILabelHasher(const Fst<Arc> &fst) : fst_(fst) {}
122
123     using Label = typename Arc::Label;
124     using StateId = typename Arc::StateId;
125
126     size_t operator()(const StateId s) {
127       const size_t p1 = 7603;
128       const size_t p2 = 433024223;
129       size_t result = p2;
130       size_t current_ilabel = kNoLabel;
131       for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
132         Label this_ilabel = aiter.Value().ilabel;
133         if (this_ilabel != current_ilabel) {  // Ignores repeats.
134           result = p1 * result + this_ilabel;
135           current_ilabel = this_ilabel;
136         }
137       }
138       return result;
139     }
140
141    private:
142     const Fst<Arc> &fst_;
143   };
144
145   class ArcIterCompare {
146    public:
147     explicit ArcIterCompare(const Partition<StateId> &partition)
148         : partition_(partition) {}
149
150     ArcIterCompare(const ArcIterCompare &comp) : partition_(comp.partition_) {}
151
152     // Compares two iterators based on their input labels.
153     bool operator()(const ArcIterator<Fst<RevArc>> *x,
154                     const ArcIterator<Fst<RevArc>> *y) const {
155       const auto &xarc = x->Value();
156       const auto &yarc = y->Value();
157       return xarc.ilabel > yarc.ilabel;
158     }
159
160    private:
161     const Partition<StateId> &partition_;
162   };
163
164   using ArcIterQueue =
165       std::priority_queue<ArcIterator<Fst<RevArc>> *,
166                           std::vector<ArcIterator<Fst<RevArc>> *>,
167                           ArcIterCompare>;
168
169  private:
170   // Prepartitions the space into equivalence classes. We ensure that final and
171   // non-final states always go into different equivalence classes, and we use
172   // class StateILabelHasher to make sure that most of the time, states with
173   // different sets of ilabels on arcs leaving them, go to different partitions.
174   // Note: for the O(n) guarantees we don't rely on the goodness of this
175   // hashing function---it just provides a bonus speedup.
176   void PrePartition(const ExpandedFst<Arc> &fst) {
177     VLOG(5) << "PrePartition";
178     StateId next_class = 0;
179     auto num_states = fst.NumStates();
180     // Allocates a temporary vector to store the initial class mappings, so that
181     // we can allocate the classes all at once.
182     std::vector<StateId> state_to_initial_class(num_states);
183     {
184       // We maintain two maps from hash-value to class---one for final states
185       // (final-prob == One()) and one for non-final states
186       // (final-prob == Zero()). We are processing unweighted acceptors, so the
187       // are the only two possible values.
188       using HashToClassMap = std::unordered_map<size_t, StateId>;
189       HashToClassMap hash_to_class_nonfinal;
190       HashToClassMap hash_to_class_final;
191       StateILabelHasher hasher(fst);
192       for (StateId s = 0; s < num_states; ++s) {
193         size_t hash = hasher(s);
194         HashToClassMap &this_map =
195             (fst.Final(s) != Weight::Zero() ? hash_to_class_final
196                                             : hash_to_class_nonfinal);
197         // Avoids two map lookups by using 'insert' instead of 'find'.
198         auto p = this_map.insert(std::make_pair(hash, next_class));
199         state_to_initial_class[s] = p.second ? next_class++ : p.first->second;
200       }
201       // Lets the unordered_maps go out of scope before we allocate the classes,
202       // to reduce the maximum amount of memory used.
203     }
204     P_.AllocateClasses(next_class);
205     for (StateId s = 0; s < num_states; ++s) {
206       P_.Add(s, state_to_initial_class[s]);
207     }
208     for (StateId c = 0; c < next_class; ++c) L_.Enqueue(c);
209     VLOG(5) << "Initial Partition: " << P_.NumClasses();
210   }
211
212   // Creates inverse transition Tr_ = rev(fst), loops over states in FST and
213   // splits on final, creating two blocks in the partition corresponding to
214   // final, non-final.
215   void Initialize(const ExpandedFst<Arc> &fst) {
216     // Constructs Tr.
217     Reverse(fst, &Tr_);
218     ILabelCompare<RevArc> ilabel_comp;
219     ArcSort(&Tr_, ilabel_comp);
220     // Tells the partition how many elements to allocate. The first state in
221     // Tr_ is super-final state.
222     P_.Initialize(Tr_.NumStates() - 1);
223     // Prepares initial partition.
224     PrePartition(fst);
225     // Allocates arc iterator queue.
226     ArcIterCompare comp(P_);
227     aiter_queue_.reset(new ArcIterQueue(comp));
228   }
229   // Partitions all classes with destination C.
230   void Split(ClassId C) {
231     // Prepares priority queue: opens arc iterator for each state in C, and
232     // inserts into priority queue.
233     for (PartitionIterator<StateId> siter(P_, C); !siter.Done(); siter.Next()) {
234       StateId s = siter.Value();
235       if (Tr_.NumArcs(s + 1)) {
236         aiter_queue_->push(new ArcIterator<Fst<RevArc>>(Tr_, s + 1));
237       }
238     }
239     // Now pops arc iterator from queue, splits entering equivalence class, and
240     // re-inserts updated iterator into queue.
241     Label prev_label = -1;
242     while (!aiter_queue_->empty()) {
243       std::unique_ptr<ArcIterator<Fst<RevArc>>> aiter(aiter_queue_->top());
244       aiter_queue_->pop();
245       if (aiter->Done()) continue;
246       const auto &arc = aiter->Value();
247       auto from_state = aiter->Value().nextstate - 1;
248       auto from_label = arc.ilabel;
249       if (prev_label != from_label) P_.FinalizeSplit(&L_);
250       auto from_class = P_.ClassId(from_state);
251       if (P_.ClassSize(from_class) > 1) P_.SplitOn(from_state);
252       prev_label = from_label;
253       aiter->Next();
254       if (!aiter->Done()) aiter_queue_->push(aiter.release());
255     }
256     P_.FinalizeSplit(&L_);
257   }
258
259   // Main loop for Hopcroft minimization.
260   void Compute(const Fst<Arc> &fst) {
261     // Processes active classes (FIFO, or FILO).
262     while (!L_.Empty()) {
263       const auto C = L_.Head();
264       L_.Dequeue();
265       Split(C);  // Splits on C, all labels in C.
266     }
267   }
268
269  private:
270   // Partioning of states into equivalence classes.
271   Partition<StateId> P_;
272   // Set of active classes to be processed in partition P.
273   Queue L_;
274   // Reverses transition function.
275   VectorFst<RevArc> Tr_;
276   // Priority queue of open arc iterators for all states in the splitter
277   // equivalence class.
278   std::unique_ptr<ArcIterQueue> aiter_queue_;
279 };
280
281 // Computes equivalence classes for acyclic FST.
282 //
283 // Complexity:
284 //
285 //   O(E)
286 //
287 // where E is the number of arcs.
288 //
289 // For more information, see:
290 //
291 // Revuz, D. 1992. Minimization of acyclic deterministic automata in linear
292 // time. Theoretical Computer Science 92(1): 181-189.
293 template <class Arc>
294 class AcyclicMinimizer {
295  public:
296   using Label = typename Arc::Label;
297   using StateId = typename Arc::StateId;
298   using ClassId = typename Arc::StateId;
299   using Weight = typename Arc::Weight;
300
301   explicit AcyclicMinimizer(const ExpandedFst<Arc> &fst) {
302     Initialize(fst);
303     Refine(fst);
304   }
305
306   const Partition<StateId> &GetPartition() { return partition_; }
307
308  private:
309   // DFS visitor to compute the height (distance) to final state.
310   class HeightVisitor {
311    public:
312     HeightVisitor() : max_height_(0), num_states_(0) {}
313
314     // Invoked before DFS visit.
315     void InitVisit(const Fst<Arc> &fst) {}
316
317     // Invoked when state is discovered (2nd arg is DFS tree root).
318     bool InitState(StateId s, StateId root) {
319       // Extends height array and initialize height (distance) to 0.
320       for (StateId i = height_.size(); i <= s; ++i) height_.push_back(-1);
321       if (s >= num_states_) num_states_ = s + 1;
322       return true;
323     }
324
325     // Invoked when tree arc examined (to undiscovered state).
326     bool TreeArc(StateId s, const Arc &arc) { return true; }
327
328     // Invoked when back arc examined (to unfinished state).
329     bool BackArc(StateId s, const Arc &arc) { return true; }
330
331     // Invoked when forward or cross arc examined (to finished state).
332     bool ForwardOrCrossArc(StateId s, const Arc &arc) {
333       if (height_[arc.nextstate] + 1 > height_[s]) {
334         height_[s] = height_[arc.nextstate] + 1;
335       }
336       return true;
337     }
338
339     // Invoked when state finished (parent is kNoStateId for tree root).
340     void FinishState(StateId s, StateId parent, const Arc *parent_arc) {
341       if (height_[s] == -1) height_[s] = 0;
342       const auto h = height_[s] + 1;
343       if (parent >= 0) {
344         if (h > height_[parent]) height_[parent] = h;
345         if (h > max_height_) max_height_ = h;
346       }
347     }
348
349     // Invoked after DFS visit.
350     void FinishVisit() {}
351
352     size_t max_height() const { return max_height_; }
353
354     const std::vector<StateId> &height() const { return height_; }
355
356     size_t num_states() const { return num_states_; }
357
358    private:
359     std::vector<StateId> height_;
360     size_t max_height_;
361     size_t num_states_;
362   };
363
364  private:
365   // Cluster states according to height (distance to final state)
366   void Initialize(const Fst<Arc> &fst) {
367     // Computes height (distance to final state).
368     HeightVisitor hvisitor;
369     DfsVisit(fst, &hvisitor);
370     // Creates initial partition based on height.
371     partition_.Initialize(hvisitor.num_states());
372     partition_.AllocateClasses(hvisitor.max_height() + 1);
373     const auto &hstates = hvisitor.height();
374     for (StateId s = 0; s < hstates.size(); ++s) partition_.Add(s, hstates[s]);
375   }
376
377   // Refines states based on arc sort (out degree, arc equivalence).
378   void Refine(const Fst<Arc> &fst) {
379     using EquivalenceMap = std::map<StateId, StateId, StateComparator<Arc>>;
380     StateComparator<Arc> comp(fst, partition_);
381     // Starts with tail (height = 0).
382     auto height = partition_.NumClasses();
383     for (StateId h = 0; h < height; ++h) {
384       EquivalenceMap equiv_classes(comp);
385       // Sorts states within equivalence class.
386       PartitionIterator<StateId> siter(partition_, h);
387       equiv_classes[siter.Value()] = h;
388       for (siter.Next(); !siter.Done(); siter.Next()) {
389         auto insert_result =
390             equiv_classes.insert(std::make_pair(siter.Value(), kNoStateId));
391         if (insert_result.second) {
392           insert_result.first->second = partition_.AddClass();
393         }
394       }
395       // Creates refined partition.
396       for (siter.Reset(); !siter.Done();) {
397         const auto s = siter.Value();
398         const auto old_class = partition_.ClassId(s);
399         const auto new_class = equiv_classes[s];
400         // A move operation can invalidate the iterator, so we first update
401         // the iterator to the next element before we move the current element
402         // out of the list.
403         siter.Next();
404         if (old_class != new_class) partition_.Move(s, new_class);
405       }
406     }
407   }
408
409  private:
410   Partition<StateId> partition_;
411 };
412
413 // Given a partition and a Mutable FST, merges states of Fst in place (i.e.,
414 // destructively). Merging works by taking the first state in a class of the
415 // partition to be the representative state for the class. Each arc is then
416 // reconnected to this state. All states in the class are merged by adding
417 // their arcs to the representative state.
418 template <class Arc>
419 void MergeStates(const Partition<typename Arc::StateId> &partition,
420                  MutableFst<Arc> *fst) {
421   using StateId = typename Arc::StateId;
422   std::vector<StateId> state_map(partition.NumClasses());
423   for (StateId i = 0; i < partition.NumClasses(); ++i) {
424     PartitionIterator<StateId> siter(partition, i);
425     state_map[i] = siter.Value();  // First state in partition.
426   }
427   // Relabels destination states.
428   for (StateId c = 0; c < partition.NumClasses(); ++c) {
429     for (PartitionIterator<StateId> siter(partition, c); !siter.Done();
430          siter.Next()) {
431       const auto s = siter.Value();
432       for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
433            aiter.Next()) {
434         auto arc = aiter.Value();
435         arc.nextstate = state_map[partition.ClassId(arc.nextstate)];
436         if (s == state_map[c]) {  // For the first state, just sets destination.
437           aiter.SetValue(arc);
438         } else {
439           fst->AddArc(state_map[c], arc);
440         }
441       }
442     }
443   }
444   fst->SetStart(state_map[partition.ClassId(fst->Start())]);
445   Connect(fst);
446 }
447
448 template <class Arc>
449 void AcceptorMinimize(MutableFst<Arc> *fst,
450                       bool allow_acyclic_minimization = true) {
451   if (!(fst->Properties(kAcceptor | kUnweighted, true) ==
452         (kAcceptor | kUnweighted))) {
453     FSTERROR() << "FST is not an unweighted acceptor";
454     fst->SetProperties(kError, kError);
455     return;
456   }
457   // Connects FST before minimization, handles disconnected states.
458   Connect(fst);
459   if (fst->NumStates() == 0) return;
460   if (allow_acyclic_minimization && fst->Properties(kAcyclic, true)) {
461     // Acyclic minimization (Revuz).
462     VLOG(2) << "Acyclic minimization";
463     ArcSort(fst, ILabelCompare<Arc>());
464     AcyclicMinimizer<Arc> minimizer(*fst);
465     MergeStates(minimizer.GetPartition(), fst);
466   } else {
467     // Either the FST has cycles, or it's generated from non-deterministic input
468     // (which the Revuz algorithm can't handle), so use the cyclic minimization
469     // algorithm of Hopcroft.
470     VLOG(2) << "Cyclic minimization";
471     CyclicMinimizer<Arc, LifoQueue<typename Arc::StateId>> minimizer(*fst);
472     MergeStates(minimizer.GetPartition(), fst);
473   }
474   // Merges in appropriate semiring
475   ArcUniqueMapper<Arc> mapper(*fst);
476   StateMap(fst, mapper);
477 }
478
479 }  // namespace internal
480
481 // In place minimization of deterministic weighted automata and transducers,
482 // and also non-deterministic ones if they use an idempotent semiring.
483 // For transducers, if the 'sfst' argument is not null, the algorithm
484 // produces a compact factorization of the minimal transducer.
485 //
486 // In the acyclic deterministic case, we use an algorithm from Revuz that is
487 // linear in the number of arcs (edges) in the machine.
488 //
489 // In the cyclic or non-deterministic case, we use the classical Hopcroft
490 // minimization (which was presented for the deterministic case but which
491 // also works for non-deterministic FSTs); this has complexity O(e log v).
492 //
493 template <class Arc>
494 void Minimize(MutableFst<Arc> *fst, MutableFst<Arc> *sfst = nullptr,
495               float delta = kDelta, bool allow_nondet = false) {
496   using Weight = typename Arc::Weight;
497   const auto props = fst->Properties(
498       kAcceptor | kIDeterministic | kWeighted | kUnweighted, true);
499   bool allow_acyclic_minimization;
500   if (props & kIDeterministic) {
501     allow_acyclic_minimization = true;
502   } else {
503     // Our approach to minimization of non-deterministic FSTs will only work in
504     // idempotent semirings---for non-deterministic inputs, a state could have
505     // multiple transitions to states that will get merged, and we'd have to
506     // sum their weights. The algorithm doesn't handle that.
507     if (!(Weight::Properties() & kIdempotent)) {
508       fst->SetProperties(kError, kError);
509       FSTERROR() << "Cannot minimize a non-deterministic FST over a "
510                     "non-idempotent semiring";
511       return;
512     } else if (!allow_nondet) {
513       fst->SetProperties(kError, kError);
514       FSTERROR() << "Refusing to minimize a non-deterministic FST with "
515                  << "allow_nondet = false";
516       return;
517     }
518     // The Revuz algorithm won't work for nondeterministic inputs, so if the
519     // input is nondeterministic, we'll have to pass a bool saying not to use
520     // that algorithm. We check at this level rather than in AcceptorMinimize(),
521     // because it's possible that the FST at this level could be deterministic,
522     // but a harmless type of non-determinism could be introduced by Encode()
523     // (thanks to kEncodeWeights, if the FST has epsilons and has a final
524     // weight with weights equal to some epsilon arc.)
525     allow_acyclic_minimization = false;
526   }
527   if (!(props & kAcceptor)) {  // Weighted transducer.
528     VectorFst<GallicArc<Arc, GALLIC_LEFT>> gfst;
529     ArcMap(*fst, &gfst, ToGallicMapper<Arc, GALLIC_LEFT>());
530     fst->DeleteStates();
531     gfst.SetProperties(kAcceptor, kAcceptor);
532     Push(&gfst, REWEIGHT_TO_INITIAL, delta);
533     ArcMap(&gfst, QuantizeMapper<GallicArc<Arc, GALLIC_LEFT>>(delta));
534     EncodeMapper<GallicArc<Arc, GALLIC_LEFT>> encoder(
535         kEncodeLabels | kEncodeWeights, ENCODE);
536     Encode(&gfst, &encoder);
537     internal::AcceptorMinimize(&gfst, allow_acyclic_minimization);
538     Decode(&gfst, encoder);
539     if (!sfst) {
540       FactorWeightFst<GallicArc<Arc, GALLIC_LEFT>,
541                       GallicFactor<typename Arc::Label, Weight, GALLIC_LEFT>>
542           fwfst(gfst);
543       std::unique_ptr<SymbolTable> osyms(
544           fst->OutputSymbols() ? fst->OutputSymbols()->Copy() : nullptr);
545       ArcMap(fwfst, fst, FromGallicMapper<Arc, GALLIC_LEFT>());
546       fst->SetOutputSymbols(osyms.get());
547     } else {
548       sfst->SetOutputSymbols(fst->OutputSymbols());
549       GallicToNewSymbolsMapper<Arc, GALLIC_LEFT> mapper(sfst);
550       ArcMap(gfst, fst, &mapper);
551       fst->SetOutputSymbols(sfst->InputSymbols());
552     }
553   } else if (props & kWeighted) {  // Weighted acceptor.
554     Push(fst, REWEIGHT_TO_INITIAL, delta);
555     ArcMap(fst, QuantizeMapper<Arc>(delta));
556     EncodeMapper<Arc> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
557     Encode(fst, &encoder);
558     internal::AcceptorMinimize(fst, allow_acyclic_minimization);
559     Decode(fst, encoder);
560   } else {  // Unweighted acceptor.
561     internal::AcceptorMinimize(fst, allow_acyclic_minimization);
562   }
563 }
564
565 }  // namespace fst
566
567 #endif  // FST_MINIMIZE_H_