ee17b4188ec28d879c9ddc73d73162274dd7a698
[platform/upstream/openfst.git] / src / include / fst / rmepsilon.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes that implemement epsilon-removal.
5
6 #ifndef FST_RMEPSILON_H_
7 #define FST_RMEPSILON_H_
8
9 #include <forward_list>
10 #include <stack>
11 #include <string>
12 #include <unordered_map>
13 #include <utility>
14 #include <vector>
15
16 #include <fst/log.h>
17
18 #include <fst/arcfilter.h>
19 #include <fst/cache.h>
20 #include <fst/connect.h>
21 #include <fst/factor-weight.h>
22 #include <fst/invert.h>
23 #include <fst/prune.h>
24 #include <fst/queue.h>
25 #include <fst/shortest-distance.h>
26 #include <fst/topsort.h>
27
28
29 namespace fst {
30
31 template <class Arc, class Queue>
32 class RmEpsilonOptions
33     : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>> {
34  public:
35   using StateId = typename Arc::StateId;
36   using Weight = typename Arc::Weight;
37
38   bool connect;             // Connect output
39   Weight weight_threshold;  // Pruning weight threshold.
40   StateId state_threshold;  // Pruning state threshold.
41
42   explicit RmEpsilonOptions(Queue *queue, float delta = kDelta,
43                             bool connect = true,
44                             Weight weight_threshold = Weight::Zero(),
45                             StateId state_threshold = kNoStateId)
46       : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>>(
47             queue, EpsilonArcFilter<Arc>(), kNoStateId, delta),
48         connect(connect),
49         weight_threshold(std::move(weight_threshold)),
50         state_threshold(state_threshold) {}
51 };
52
53 namespace internal {
54
55 // Computation state of the epsilon-removal algorithm.
56 template <class Arc, class Queue>
57 class RmEpsilonState {
58  public:
59   using Label = typename Arc::Label;
60   using StateId = typename Arc::StateId;
61   using Weight = typename Arc::Weight;
62
63   RmEpsilonState(const Fst<Arc> &fst, std::vector<Weight> *distance,
64                  const RmEpsilonOptions<Arc, Queue> &opts)
65       : fst_(fst),
66         distance_(distance),
67         sd_state_(fst_, distance, opts, true),
68         expand_id_(0) {}
69
70   void Expand(StateId s);
71
72   std::vector<Arc> &Arcs() { return arcs_; }
73
74   const Weight &Final() const { return final_; }
75
76   bool Error() const { return sd_state_.Error(); }
77
78  private:
79   struct Element {
80     Label ilabel;
81     Label olabel;
82     StateId nextstate;
83
84     Element() {}
85
86     Element(Label ilabel, Label olabel, StateId nexstate)
87         : ilabel(ilabel), olabel(olabel), nextstate(nexstate) {}
88   };
89
90   struct ElementHash {
91    public:
92     size_t operator()(const Element &element) const {
93       static constexpr size_t prime0 = 7853;
94       static constexpr size_t prime1 = 7867;
95       return static_cast<size_t>(element.nextstate) +
96              static_cast<size_t>(element.ilabel) * prime0 +
97              static_cast<size_t>(element.olabel) * prime1;
98     }
99   };
100
101   class ElementEqual {
102    public:
103     bool operator()(const Element &e1, const Element &e2) const {
104       return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) &&
105              (e1.nextstate == e2.nextstate);
106     }
107   };
108
109   using ElementMap = std::unordered_map<Element, std::pair<StateId, size_t>,
110                                         ElementHash, ElementEqual>;
111
112   const Fst<Arc> &fst_;
113   // Distance from state being expanded in epsilon-closure.
114   std::vector<Weight> *distance_;
115   // Shortest distance algorithm computation state.
116   internal::ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc>> sd_state_;
117   // Maps an element to a pair corresponding to a position in the arcs vector
118   // of the state being expanded. The element corresopnds to the position in
119   // the arcs_ vector if p.first is equal to the state being expanded.
120   ElementMap element_map_;
121   EpsilonArcFilter<Arc> eps_filter_;
122   std::stack<StateId> eps_queue_;  // Queue used to visit the epsilon-closure.
123   std::vector<bool> visited_;      // True if the state has been visited.
124   std::forward_list<StateId> visited_states_;  // List of visited states.
125   std::vector<Arc> arcs_;                      // Arcs of state being expanded.
126   Weight final_;       // Final weight of state being expanded.
127   StateId expand_id_;  // Unique ID for each call to Expand
128
129   RmEpsilonState(const RmEpsilonState &) = delete;
130   RmEpsilonState &operator=(const RmEpsilonState &) = delete;
131 };
132
133 template <class Arc, class Queue>
134 void RmEpsilonState<Arc, Queue>::Expand(typename Arc::StateId source) {
135   final_ = Weight::Zero();
136   arcs_.clear();
137   sd_state_.ShortestDistance(source);
138   if (sd_state_.Error()) return;
139   eps_queue_.push(source);
140   while (!eps_queue_.empty()) {
141     const auto state = eps_queue_.top();
142     eps_queue_.pop();
143     while (visited_.size() <= state) visited_.push_back(false);
144     if (visited_[state]) continue;
145     visited_[state] = true;
146     visited_states_.push_front(state);
147     for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
148          aiter.Next()) {
149       auto arc = aiter.Value();
150       arc.weight = Times((*distance_)[state], arc.weight);
151       if (eps_filter_(arc)) {
152         while (visited_.size() <= arc.nextstate) visited_.push_back(false);
153         if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate);
154       } else {
155         const Element element(arc.ilabel, arc.olabel, arc.nextstate);
156         auto insert_result = element_map_.insert(
157             std::make_pair(element, std::make_pair(expand_id_, arcs_.size())));
158         if (insert_result.second) {
159           arcs_.push_back(arc);
160         } else {
161           if (insert_result.first->second.first == expand_id_) {
162             auto &weight = arcs_[insert_result.first->second.second].weight;
163             weight = Plus(weight, arc.weight);
164           } else {
165             insert_result.first->second.first = expand_id_;
166             insert_result.first->second.second = arcs_.size();
167             arcs_.push_back(arc);
168           }
169         }
170       }
171     }
172     final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
173   }
174   while (!visited_states_.empty()) {
175     visited_[visited_states_.front()] = false;
176     visited_states_.pop_front();
177   }
178   ++expand_id_;
179 }
180
181 }  // namespace internal
182
183 // Removes epsilon-transitions (when both the input and output label are an
184 // epsilon) from a transducer. The result will be an equivalent FST that has no
185 // such epsilon transitions. This version modifies its input. It allows fine
186 // control via the options argument; see below for a simpler interface.
187 //
188 // The distance vector will be used to hold the shortest distances during the
189 // epsilon-closure computation. The state queue discipline and convergence delta
190 // are taken in the options argument.
191 template <class Arc, class Queue>
192 void RmEpsilon(MutableFst<Arc> *fst,
193                std::vector<typename Arc::Weight> *distance,
194                const RmEpsilonOptions<Arc, Queue> &opts) {
195   using Label = typename Arc::Label;
196   using StateId = typename Arc::StateId;
197   using Weight = typename Arc::Weight;
198   if (fst->Start() == kNoStateId) return;
199   // noneps_in[s] will be set to true iff s admits a non-epsilon incoming
200   // transition or is the start state.
201   std::vector<bool> noneps_in(fst->NumStates(), false);
202   noneps_in[fst->Start()] = true;
203   for (size_t i = 0; i < fst->NumStates(); ++i) {
204     for (ArcIterator<Fst<Arc>> aiter(*fst, i); !aiter.Done(); aiter.Next()) {
205       const auto &arc = aiter.Value();
206       if (arc.ilabel != 0 || arc.olabel != 0) {
207         noneps_in[arc.nextstate] = true;
208       }
209     }
210   }
211   // States sorted in topological order when (acyclic) or generic topological
212   // order (cyclic).
213   std::vector<StateId> states;
214   states.reserve(fst->NumStates());
215   if (fst->Properties(kTopSorted, false) & kTopSorted) {
216     for (size_t i = 0; i < fst->NumStates(); i++) states.push_back(i);
217   } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
218     std::vector<StateId> order;
219     bool acyclic;
220     TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
221     DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
222     // Sanity check: should be acyclic if property bit is set.
223     if (!acyclic) {
224       FSTERROR() << "RmEpsilon: Inconsistent acyclic property bit";
225       fst->SetProperties(kError, kError);
226       return;
227     }
228     states.resize(order.size());
229     for (StateId i = 0; i < order.size(); i++) states[order[i]] = i;
230   } else {
231     uint64 props;
232     std::vector<StateId> scc;
233     SccVisitor<Arc> scc_visitor(&scc, nullptr, nullptr, &props);
234     DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
235     std::vector<StateId> first(scc.size(), kNoStateId);
236     std::vector<StateId> next(scc.size(), kNoStateId);
237     for (StateId i = 0; i < scc.size(); i++) {
238       if (first[scc[i]] != kNoStateId) next[i] = first[scc[i]];
239       first[scc[i]] = i;
240     }
241     for (StateId i = 0; i < first.size(); i++) {
242       for (auto j = first[i]; j != kNoStateId; j = next[j]) {
243         states.push_back(j);
244       }
245     }
246   }
247   internal::RmEpsilonState<Arc, Queue> rmeps_state(*fst, distance, opts);
248   while (!states.empty()) {
249     const auto state = states.back();
250     states.pop_back();
251     if (!noneps_in[state] &&
252         (opts.connect || opts.weight_threshold != Weight::Zero() ||
253          opts.state_threshold != kNoStateId)) {
254       continue;
255     }
256     rmeps_state.Expand(state);
257     fst->SetFinal(state, rmeps_state.Final());
258     fst->DeleteArcs(state);
259     auto &arcs = rmeps_state.Arcs();
260     fst->ReserveArcs(state, arcs.size());
261     while (!arcs.empty()) {
262       fst->AddArc(state, arcs.back());
263       arcs.pop_back();
264     }
265   }
266   if (opts.connect || opts.weight_threshold != Weight::Zero() ||
267       opts.state_threshold != kNoStateId) {
268     for (size_t s = 0; s < fst->NumStates(); ++s) {
269       if (!noneps_in[s]) fst->DeleteArcs(s);
270     }
271   }
272   if (rmeps_state.Error()) fst->SetProperties(kError, kError);
273   fst->SetProperties(
274       RmEpsilonProperties(fst->Properties(kFstProperties, false)),
275       kFstProperties);
276   if (opts.weight_threshold != Weight::Zero() ||
277       opts.state_threshold != kNoStateId) {
278     Prune(fst, opts.weight_threshold, opts.state_threshold);
279   }
280   if (opts.connect && opts.weight_threshold == Weight::Zero() &&
281       opts.state_threshold == kNoStateId) {
282     Connect(fst);
283   }
284 }
285
286 // Removes epsilon-transitions (when both the input and output label
287 // are an epsilon) from a transducer. The result will be an equivalent
288 // FST that has no such epsilon transitions. This version modifies its
289 // input. It has a simplified interface; see above for a version that
290 // allows finer control.
291 //
292 // Complexity:
293 //
294 // - Time:
295 //
296 //   Unweighted: O(v^2 + ve).
297 //   Acyclic: O(v^2 + V e).
298 //   Tropical semiring: O(v^2 log V + ve).
299 //   General: exponential.
300 //
301 // - Space: O(vE)
302 //
303 // where v is the number of states visited and e is the number of arcs visited.
304 //
305 // For more information, see:
306 //
307 // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
308 // algorithms for weighted transducers. International Journal of Computer
309 // Science 13(1): 129-143.
310 template <class Arc>
311 void RmEpsilon(MutableFst<Arc> *fst, bool connect = true,
312                typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
313                typename Arc::StateId state_threshold = kNoStateId,
314                float delta = kDelta) {
315   using StateId = typename Arc::StateId;
316   using Weight = typename Arc::Weight;
317   std::vector<Weight> distance;
318   AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
319   RmEpsilonOptions<Arc, AutoQueue<StateId>> opts(
320       &state_queue, delta, connect, weight_threshold, state_threshold);
321   RmEpsilon(fst, &distance, opts);
322 }
323
324 struct RmEpsilonFstOptions : CacheOptions {
325   float delta;
326
327   explicit RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
328       : CacheOptions(opts), delta(delta) {}
329
330   explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
331 };
332
333 namespace internal {
334
335 // Implementation of delayed RmEpsilonFst.
336 template <class Arc>
337 class RmEpsilonFstImpl : public CacheImpl<Arc> {
338  public:
339   using StateId = typename Arc::StateId;
340   using Weight = typename Arc::Weight;
341
342   using Store = DefaultCacheStore<Arc>;
343   using State = typename Store::State;
344
345   using FstImpl<Arc>::Properties;
346   using FstImpl<Arc>::SetType;
347   using FstImpl<Arc>::SetProperties;
348   using FstImpl<Arc>::SetInputSymbols;
349   using FstImpl<Arc>::SetOutputSymbols;
350
351   using CacheBaseImpl<CacheState<Arc>>::HasArcs;
352   using CacheBaseImpl<CacheState<Arc>>::HasFinal;
353   using CacheBaseImpl<CacheState<Arc>>::HasStart;
354   using CacheBaseImpl<CacheState<Arc>>::PushArc;
355   using CacheBaseImpl<CacheState<Arc>>::SetArcs;
356   using CacheBaseImpl<CacheState<Arc>>::SetFinal;
357   using CacheBaseImpl<CacheState<Arc>>::SetStart;
358
359   RmEpsilonFstImpl(const Fst<Arc> &fst, const RmEpsilonFstOptions &opts)
360       : CacheImpl<Arc>(opts),
361         fst_(fst.Copy()),
362         delta_(opts.delta),
363         rmeps_state_(
364             *fst_, &distance_,
365             RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false)) {
366     SetType("rmepsilon");
367     SetProperties(
368         RmEpsilonProperties(fst.Properties(kFstProperties, false), true),
369         kCopyProperties);
370     SetInputSymbols(fst.InputSymbols());
371     SetOutputSymbols(fst.OutputSymbols());
372   }
373
374   RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
375       : CacheImpl<Arc>(impl),
376         fst_(impl.fst_->Copy(true)),
377         delta_(impl.delta_),
378         rmeps_state_(
379             *fst_, &distance_,
380             RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false)) {
381     SetType("rmepsilon");
382     SetProperties(impl.Properties(), kCopyProperties);
383     SetInputSymbols(impl.InputSymbols());
384     SetOutputSymbols(impl.OutputSymbols());
385   }
386
387   StateId Start() {
388     if (!HasStart()) SetStart(fst_->Start());
389     return CacheImpl<Arc>::Start();
390   }
391
392   Weight Final(StateId s) {
393     if (!HasFinal(s)) Expand(s);
394     return CacheImpl<Arc>::Final(s);
395   }
396
397   size_t NumArcs(StateId s) {
398     if (!HasArcs(s)) Expand(s);
399     return CacheImpl<Arc>::NumArcs(s);
400   }
401
402   size_t NumInputEpsilons(StateId s) {
403     if (!HasArcs(s)) Expand(s);
404     return CacheImpl<Arc>::NumInputEpsilons(s);
405   }
406
407   size_t NumOutputEpsilons(StateId s) {
408     if (!HasArcs(s)) Expand(s);
409     return CacheImpl<Arc>::NumOutputEpsilons(s);
410   }
411
412   uint64 Properties() const override { return Properties(kFstProperties); }
413
414   // Sets error if found and returns other FST impl properties.
415   uint64 Properties(uint64 mask) const override {
416     if ((mask & kError) &&
417         (fst_->Properties(kError, false) || rmeps_state_.Error())) {
418       SetProperties(kError, kError);
419     }
420     return FstImpl<Arc>::Properties(mask);
421   }
422
423   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
424     if (!HasArcs(s)) Expand(s);
425     CacheImpl<Arc>::InitArcIterator(s, data);
426   }
427
428   void Expand(StateId s) {
429     rmeps_state_.Expand(s);
430     SetFinal(s, rmeps_state_.Final());
431     auto &arcs = rmeps_state_.Arcs();
432     while (!arcs.empty()) {
433       PushArc(s, arcs.back());
434       arcs.pop_back();
435     }
436     SetArcs(s);
437   }
438
439  private:
440   std::unique_ptr<const Fst<Arc>> fst_;
441   float delta_;
442   std::vector<Weight> distance_;
443   FifoQueue<StateId> queue_;
444   internal::RmEpsilonState<Arc, FifoQueue<StateId>> rmeps_state_;
445 };
446
447 }  // namespace internal
448
449 // Removes epsilon-transitions (when both the input and output label are an
450 // epsilon) from a transducer. The result will be an equivalent FST that has no
451 // such epsilon transitions. This version is a
452 // delayed FST.
453 //
454 // Complexity:
455 //
456 // - Time:
457 //   Unweighted: O(v^2 + ve).
458 //   General: exponential.
459 //
460 // - Space: O(vE)
461 //
462 // where v is the number of states visited and e is the number of arcs visited.
463 // Constant time to visit an input state or arc is assumed and exclusive of
464 // caching.
465 //
466 // For more information, see:
467 //
468 // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
469 // algorithms for weighted transducers. International Journal of Computer
470 // Science 13(1): 129-143.
471 //
472 // This class attaches interface to implementation and handles
473 // reference counting, delegating most methods to ImplToFst.
474 template <class A>
475 class RmEpsilonFst : public ImplToFst<internal::RmEpsilonFstImpl<A>> {
476  public:
477   using Arc = A;
478   using StateId = typename Arc::StateId;
479
480   using Store = DefaultCacheStore<Arc>;
481   using State = typename Store::State;
482   using Impl = internal::RmEpsilonFstImpl<Arc>;
483
484   friend class ArcIterator<RmEpsilonFst<Arc>>;
485   friend class StateIterator<RmEpsilonFst<Arc>>;
486
487   explicit RmEpsilonFst(const Fst<Arc> &fst)
488       : ImplToFst<Impl>(std::make_shared<Impl>(fst, RmEpsilonFstOptions())) {}
489
490   RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
491       : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
492
493   // See Fst<>::Copy() for doc.
494   RmEpsilonFst(const RmEpsilonFst<Arc> &fst, bool safe = false)
495       : ImplToFst<Impl>(fst, safe) {}
496
497   // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
498   RmEpsilonFst<Arc> *Copy(bool safe = false) const override {
499     return new RmEpsilonFst<Arc>(*this, safe);
500   }
501
502   inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
503
504   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
505     GetMutableImpl()->InitArcIterator(s, data);
506   }
507
508  private:
509   using ImplToFst<Impl>::GetImpl;
510   using ImplToFst<Impl>::GetMutableImpl;
511
512   RmEpsilonFst &operator=(const RmEpsilonFst &) = delete;
513 };
514
515 // Specialization for RmEpsilonFst.
516 template <class Arc>
517 class StateIterator<RmEpsilonFst<Arc>>
518     : public CacheStateIterator<RmEpsilonFst<Arc>> {
519  public:
520   explicit StateIterator(const RmEpsilonFst<Arc> &fst)
521       : CacheStateIterator<RmEpsilonFst<Arc>>(fst, fst.GetMutableImpl()) {}
522 };
523
524 // Specialization for RmEpsilonFst.
525 template <class Arc>
526 class ArcIterator<RmEpsilonFst<Arc>>
527     : public CacheArcIterator<RmEpsilonFst<Arc>> {
528  public:
529   using StateId = typename Arc::StateId;
530
531   ArcIterator(const RmEpsilonFst<Arc> &fst, StateId s)
532       : CacheArcIterator<RmEpsilonFst<Arc>>(fst.GetMutableImpl(), s) {
533     if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
534   }
535 };
536
537 template <class Arc>
538 inline void RmEpsilonFst<Arc>::InitStateIterator(
539     StateIteratorData<Arc> *data) const {
540   data->base = new StateIterator<RmEpsilonFst<Arc>>(*this);
541 }
542
543 // Useful alias when using StdArc.
544 using StdRmEpsilonFst = RmEpsilonFst<StdArc>;
545
546 }  // namespace fst
547
548 #endif  // FST_RMEPSILON_H_