d36158da8c0c9c7608ab9a6d396e2bd466c98683
[platform/upstream/openfst.git] / src / include / fst / rmepsilon.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes that implemement epsilon-removal.
5
6 #ifndef FST_LIB_RMEPSILON_H_
7 #define FST_LIB_RMEPSILON_H_
8
9 #include <forward_list>
10 #include <stack>
11 #include <string>
12 #include <unordered_map>
13 #include <utility>
14 #include <vector>
15
16 #include <fst/log.h>
17
18 #include <fst/arcfilter.h>
19 #include <fst/cache.h>
20 #include <fst/connect.h>
21 #include <fst/factor-weight.h>
22 #include <fst/invert.h>
23 #include <fst/prune.h>
24 #include <fst/queue.h>
25 #include <fst/shortest-distance.h>
26 #include <fst/topsort.h>
27
28
29 namespace fst {
30
31 template <class Arc, class Queue>
32 class RmEpsilonOptions
33     : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>> {
34  public:
35   using StateId = typename Arc::StateId;
36   using Weight = typename Arc::Weight;
37
38   bool connect;             // Connect output
39   Weight weight_threshold;  // Pruning weight threshold.
40   StateId state_threshold;  // Pruning state threshold.
41
42   explicit RmEpsilonOptions(Queue *queue, float delta = kDelta,
43                             bool connect = true,
44                             Weight weight_threshold = Weight::Zero(),
45                             StateId state_threshold = kNoStateId)
46       : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>>(
47             queue, EpsilonArcFilter<Arc>(), kNoStateId, delta),
48         connect(connect),
49         weight_threshold(std::move(weight_threshold)),
50         state_threshold(state_threshold) {}
51 };
52
53 namespace internal {
54
55 // Computation state of the epsilon-removal algorithm.
56 template <class Arc, class Queue>
57 class RmEpsilonState {
58  public:
59   using Label = typename Arc::Label;
60   using StateId = typename Arc::StateId;
61   using Weight = typename Arc::Weight;
62
63   RmEpsilonState(const Fst<Arc> &fst, std::vector<Weight> *distance,
64                  const RmEpsilonOptions<Arc, Queue> &opts)
65       : fst_(fst),
66         distance_(distance),
67         sd_state_(fst_, distance, opts, true),
68         expand_id_(0) {}
69
70   void Expand(StateId s);
71
72   std::vector<Arc> &Arcs() { return arcs_; }
73
74   const Weight &Final() const { return final_; }
75
76   bool Error() const { return sd_state_.Error(); }
77
78  private:
79   struct Element {
80     Label ilabel;
81     Label olabel;
82     StateId nextstate;
83
84     Element() {}
85
86     Element(Label ilabel, Label olabel, StateId nexstate)
87         : ilabel(ilabel), olabel(olabel), nextstate(nexstate) {}
88   };
89
90   struct ElementHash {
91    public:
92     size_t operator()(const Element &element) const {
93       static constexpr auto prime0 = 7853;
94       static constexpr auto prime1 = 7867;
95       return element.nextstate + element.ilabel * prime0 +
96              element.olabel * prime1;
97     }
98   };
99
100   class ElementEqual {
101    public:
102     bool operator()(const Element &e1, const Element &e2) const {
103       return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) &&
104              (e1.nextstate == e2.nextstate);
105     }
106   };
107
108   using ElementMap = std::unordered_map<Element, std::pair<StateId, size_t>,
109                                         ElementHash, ElementEqual>;
110
111   const Fst<Arc> &fst_;
112   // Distance from state being expanded in epsilon-closure.
113   std::vector<Weight> *distance_;
114   // Shortest distance algorithm computation state.
115   internal::ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc>> sd_state_;
116   // Maps an element to a pair corresponding to a position in the arcs vector
117   // of the state being expanded. The element corresopnds to the position in
118   // the arcs_ vector if p.first is equal to the state being expanded.
119   ElementMap element_map_;
120   EpsilonArcFilter<Arc> eps_filter_;
121   std::stack<StateId> eps_queue_;  // Queue used to visit the epsilon-closure.
122   std::vector<bool> visited_;      // True if the state has been visited.
123   std::forward_list<StateId> visited_states_;  // List of visited states.
124   std::vector<Arc> arcs_;                      // Arcs of state being expanded
125   Weight final_;       // Final weight of state being expanded.
126   StateId expand_id_;  // Unique ID for each call to Expand
127
128   RmEpsilonState(const RmEpsilonState &) = delete;
129   RmEpsilonState &operator=(const RmEpsilonState &) = delete;
130 };
131
132 template <class Arc, class Queue>
133 void RmEpsilonState<Arc, Queue>::Expand(typename Arc::StateId source) {
134   final_ = Weight::Zero();
135   arcs_.clear();
136   sd_state_.ShortestDistance(source);
137   if (sd_state_.Error()) return;
138   eps_queue_.push(source);
139   while (!eps_queue_.empty()) {
140     const auto state = eps_queue_.top();
141     eps_queue_.pop();
142     while (visited_.size() <= state) visited_.push_back(false);
143     if (visited_[state]) continue;
144     visited_[state] = true;
145     visited_states_.push_front(state);
146     for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
147          aiter.Next()) {
148       auto arc = aiter.Value();
149       arc.weight = Times((*distance_)[state], arc.weight);
150       if (eps_filter_(arc)) {
151         while (visited_.size() <= arc.nextstate) visited_.push_back(false);
152         if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate);
153       } else {
154         const Element element(arc.ilabel, arc.olabel, arc.nextstate);
155         auto insert_result = element_map_.insert(
156             std::make_pair(element, std::make_pair(expand_id_, arcs_.size())));
157         if (insert_result.second) {
158           arcs_.push_back(arc);
159         } else {
160           if (insert_result.first->second.first == expand_id_) {
161             auto &weight = arcs_[insert_result.first->second.second].weight;
162             weight = Plus(weight, arc.weight);
163           } else {
164             insert_result.first->second.first = expand_id_;
165             insert_result.first->second.second = arcs_.size();
166             arcs_.push_back(arc);
167           }
168         }
169       }
170     }
171     final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
172   }
173   while (!visited_states_.empty()) {
174     visited_[visited_states_.front()] = false;
175     visited_states_.pop_front();
176   }
177   ++expand_id_;
178 }
179
180 }  // namespace internal
181
182 // Removes epsilon-transitions (when both the input and output label are an
183 // epsilon) from a transducer. The result will be an equivalent FST that has no
184 // such epsilon transitions. This version modifies its input. It allows fine
185 // control via the options argument; see below for a simpler interface.
186 //
187 // Thei distance vector will be used to hold the shortest distances during the
188 // epsilon-closure computation. The state queue discipline and convergence delta
189 // are taken in the options argument.
190 template <class Arc, class Queue>
191 void RmEpsilon(MutableFst<Arc> *fst,
192                std::vector<typename Arc::Weight> *distance,
193                const RmEpsilonOptions<Arc, Queue> &opts) {
194   using Label = typename Arc::Label;
195   using StateId = typename Arc::StateId;
196   using Weight = typename Arc::Weight;
197   if (fst->Start() == kNoStateId) return;
198   // noneps_in[s] will be set to true iff s'admits a non-epsilon incoming
199   // transition or is the start state.
200   std::vector<bool> noneps_in(fst->NumStates(), false);
201   noneps_in[fst->Start()] = true;
202   for (size_t i = 0; i < fst->NumStates(); ++i) {
203     for (ArcIterator<Fst<Arc>> aiter(*fst, i); !aiter.Done(); aiter.Next()) {
204       const auto &arc = aiter.Value();
205       if (arc.ilabel != 0 || arc.olabel != 0) {
206         noneps_in[arc.nextstate] = true;
207       }
208     }
209   }
210   // States sorted in topological order when (acyclic) or generic topological
211   // order (cyclic).
212   std::vector<StateId> states;
213   states.reserve(fst->NumStates());
214   if (fst->Properties(kTopSorted, false) & kTopSorted) {
215     for (size_t i = 0; i < fst->NumStates(); i++) states.push_back(i);
216   } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
217     std::vector<StateId> order;
218     bool acyclic;
219     TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
220     DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
221     // Sanity check: should be acyclic if property bit is set.
222     if (!acyclic) {
223       FSTERROR() << "RmEpsilon: Inconsistent acyclic property bit";
224       fst->SetProperties(kError, kError);
225       return;
226     }
227     states.resize(order.size());
228     for (StateId i = 0; i < order.size(); i++) states[order[i]] = i;
229   } else {
230     uint64 props;
231     std::vector<StateId> scc;
232     SccVisitor<Arc> scc_visitor(&scc, nullptr, nullptr, &props);
233     DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
234     std::vector<StateId> first(scc.size(), kNoStateId);
235     std::vector<StateId> next(scc.size(), kNoStateId);
236     for (StateId i = 0; i < scc.size(); i++) {
237       if (first[scc[i]] != kNoStateId) next[i] = first[scc[i]];
238       first[scc[i]] = i;
239     }
240     for (StateId i = 0; i < first.size(); i++) {
241       for (auto j = first[i]; j != kNoStateId; j = next[j]) {
242         states.push_back(j);
243       }
244     }
245   }
246   internal::RmEpsilonState<Arc, Queue> rmeps_state(*fst, distance, opts);
247   while (!states.empty()) {
248     const auto state = states.back();
249     states.pop_back();
250     if (!noneps_in[state] &&
251         (opts.connect || opts.weight_threshold != Weight::Zero() ||
252          opts.state_threshold != kNoStateId)) {
253       continue;
254     }
255     rmeps_state.Expand(state);
256     fst->SetFinal(state, rmeps_state.Final());
257     fst->DeleteArcs(state);
258     auto &arcs = rmeps_state.Arcs();
259     fst->ReserveArcs(state, arcs.size());
260     while (!arcs.empty()) {
261       fst->AddArc(state, arcs.back());
262       arcs.pop_back();
263     }
264   }
265   if (opts.connect || opts.weight_threshold != Weight::Zero() ||
266       opts.state_threshold != kNoStateId) {
267     for (size_t s = 0; s < fst->NumStates(); ++s) {
268       if (!noneps_in[s]) fst->DeleteArcs(s);
269     }
270   }
271   if (rmeps_state.Error()) fst->SetProperties(kError, kError);
272   fst->SetProperties(
273       RmEpsilonProperties(fst->Properties(kFstProperties, false)),
274       kFstProperties);
275   if (opts.weight_threshold != Weight::Zero() ||
276       opts.state_threshold != kNoStateId) {
277     Prune(fst, opts.weight_threshold, opts.state_threshold);
278   }
279   if (opts.connect && opts.weight_threshold == Weight::Zero() &&
280       opts.state_threshold == kNoStateId) {
281     Connect(fst);
282   }
283 }
284
285 // Removes epsilon-transitions (when both the input and output label
286 // are an epsilon) from a transducer. The result will be an equivalent
287 // FST that has no such epsilon transitions. This version modifies its
288 // input. It has a simplified interface; see above for a version that
289 // allows finer control.
290 //
291 // Complexity:
292 //
293 // - Time:
294 //
295 //   Unweighted: O(v^2 + ve).
296 //   Acyclic: O(v^2 + V e).
297 //   Tropical semiring: O(v^2 log V + ve).
298 //   General: exponential.
299 //
300 // - Space: O(vE)
301 //
302 // where v is the number of states visited and e is the number of arcs visited.
303 //
304 // For more information, see:
305 //
306 // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
307 // algorithms for weighted transducers. International Journal of Computer
308 // Science 13(1): 129-143.
309 template <class Arc>
310 void RmEpsilon(MutableFst<Arc> *fst, bool connect = true,
311                typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
312                typename Arc::StateId state_threshold = kNoStateId,
313                float delta = kDelta) {
314   using StateId = typename Arc::StateId;
315   using Weight = typename Arc::Weight;
316   std::vector<Weight> distance;
317   AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
318   RmEpsilonOptions<Arc, AutoQueue<StateId>> opts(
319       &state_queue, delta, connect, weight_threshold, state_threshold);
320   RmEpsilon(fst, &distance, opts);
321 }
322
323 struct RmEpsilonFstOptions : CacheOptions {
324   float delta;
325
326   explicit RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
327       : CacheOptions(opts), delta(delta) {}
328
329   explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
330 };
331
332 namespace internal {
333
334 // Implementation of delayed RmEpsilonFst.
335 template <class Arc>
336 class RmEpsilonFstImpl : public CacheImpl<Arc> {
337  public:
338   using StateId = typename Arc::StateId;
339   using Weight = typename Arc::Weight;
340
341   using Store = DefaultCacheStore<Arc>;
342   using State = typename Store::State;
343
344   using FstImpl<Arc>::Properties;
345   using FstImpl<Arc>::SetType;
346   using FstImpl<Arc>::SetProperties;
347   using FstImpl<Arc>::SetInputSymbols;
348   using FstImpl<Arc>::SetOutputSymbols;
349
350   using CacheBaseImpl<CacheState<Arc>>::HasArcs;
351   using CacheBaseImpl<CacheState<Arc>>::HasFinal;
352   using CacheBaseImpl<CacheState<Arc>>::HasStart;
353   using CacheBaseImpl<CacheState<Arc>>::PushArc;
354   using CacheBaseImpl<CacheState<Arc>>::SetArcs;
355   using CacheBaseImpl<CacheState<Arc>>::SetFinal;
356   using CacheBaseImpl<CacheState<Arc>>::SetStart;
357
358   RmEpsilonFstImpl(const Fst<Arc> &fst, const RmEpsilonFstOptions &opts)
359       : CacheImpl<Arc>(opts),
360         fst_(fst.Copy()),
361         delta_(opts.delta),
362         rmeps_state_(
363             *fst_, &distance_,
364             RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false)) {
365     SetType("rmepsilon");
366     SetProperties(
367         RmEpsilonProperties(fst.Properties(kFstProperties, false), true),
368         kCopyProperties);
369     SetInputSymbols(fst.InputSymbols());
370     SetOutputSymbols(fst.OutputSymbols());
371   }
372
373   RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
374       : CacheImpl<Arc>(impl),
375         fst_(impl.fst_->Copy(true)),
376         delta_(impl.delta_),
377         rmeps_state_(
378             *fst_, &distance_,
379             RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false)) {
380     SetType("rmepsilon");
381     SetProperties(impl.Properties(), kCopyProperties);
382     SetInputSymbols(impl.InputSymbols());
383     SetOutputSymbols(impl.OutputSymbols());
384   }
385
386   StateId Start() {
387     if (!HasStart()) SetStart(fst_->Start());
388     return CacheImpl<Arc>::Start();
389   }
390
391   Weight Final(StateId s) {
392     if (!HasFinal(s)) Expand(s);
393     return CacheImpl<Arc>::Final(s);
394   }
395
396   size_t NumArcs(StateId s) {
397     if (!HasArcs(s)) Expand(s);
398     return CacheImpl<Arc>::NumArcs(s);
399   }
400
401   size_t NumInputEpsilons(StateId s) {
402     if (!HasArcs(s)) Expand(s);
403     return CacheImpl<Arc>::NumInputEpsilons(s);
404   }
405
406   size_t NumOutputEpsilons(StateId s) {
407     if (!HasArcs(s)) Expand(s);
408     return CacheImpl<Arc>::NumOutputEpsilons(s);
409   }
410
411   uint64 Properties() const override { return Properties(kFstProperties); }
412
413   // Sets error if found and returns other FST impl properties.
414   uint64 Properties(uint64 mask) const override {
415     if ((mask & kError) &&
416         (fst_->Properties(kError, false) || rmeps_state_.Error())) {
417       SetProperties(kError, kError);
418     }
419     return FstImpl<Arc>::Properties(mask);
420   }
421
422   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
423     if (!HasArcs(s)) Expand(s);
424     CacheImpl<Arc>::InitArcIterator(s, data);
425   }
426
427   void Expand(StateId s) {
428     rmeps_state_.Expand(s);
429     SetFinal(s, rmeps_state_.Final());
430     auto &arcs = rmeps_state_.Arcs();
431     while (!arcs.empty()) {
432       PushArc(s, arcs.back());
433       arcs.pop_back();
434     }
435     SetArcs(s);
436   }
437
438  private:
439   std::unique_ptr<const Fst<Arc>> fst_;
440   float delta_;
441   std::vector<Weight> distance_;
442   FifoQueue<StateId> queue_;
443   internal::RmEpsilonState<Arc, FifoQueue<StateId>> rmeps_state_;
444 };
445
446 }  // namespace internal
447
448 // Removes epsilon-transitions (when both the input and output label are an
449 // epsilon) from a transducer. The result will be an equivalent FST that has no
450 // such epsilon transitions. This version is a
451 // delayed FST.
452 //
453 // Complexity:
454 //
455 // - Time:
456 //   Unweighted: O(v^2 + ve).
457 //   General: exponential.
458 //
459 // - Space: O(vE)
460 //
461 // where v is the number of states visited and e is the number of arcs visited.
462 // Constant time to visit an input state or arc is assumed and exclusive of
463 // caching.
464 //
465 // For more information, see:
466 //
467 // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
468 // algorithms for weighted transducers. International Journal of Computer
469 // Science 13(1): 129-143.
470 //
471 // This class attaches interface to implementation and handles
472 // reference counting, delegating most methods to ImplToFst.
473 template <class A>
474 class RmEpsilonFst : public ImplToFst<internal::RmEpsilonFstImpl<A>> {
475  public:
476   using Arc = A;
477   using StateId = typename Arc::StateId;
478
479   using Store = DefaultCacheStore<Arc>;
480   using State = typename Store::State;
481   using Impl = internal::RmEpsilonFstImpl<Arc>;
482
483   friend class ArcIterator<RmEpsilonFst<Arc>>;
484   friend class StateIterator<RmEpsilonFst<Arc>>;
485
486   explicit RmEpsilonFst(const Fst<Arc> &fst)
487       : ImplToFst<Impl>(std::make_shared<Impl>(fst, RmEpsilonFstOptions())) {}
488
489   RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
490       : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
491
492   // See Fst<>::Copy() for doc.
493   RmEpsilonFst(const RmEpsilonFst<Arc> &fst, bool safe = false)
494       : ImplToFst<Impl>(fst, safe) {}
495
496   // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
497   RmEpsilonFst<Arc> *Copy(bool safe = false) const override {
498     return new RmEpsilonFst<Arc>(*this, safe);
499   }
500
501   inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
502
503   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
504     GetMutableImpl()->InitArcIterator(s, data);
505   }
506
507  private:
508   using ImplToFst<Impl>::GetImpl;
509   using ImplToFst<Impl>::GetMutableImpl;
510
511   RmEpsilonFst &operator=(const RmEpsilonFst &) = delete;
512 };
513
514 // Specialization for RmEpsilonFst.
515 template <class Arc>
516 class StateIterator<RmEpsilonFst<Arc>>
517     : public CacheStateIterator<RmEpsilonFst<Arc>> {
518  public:
519   explicit StateIterator(const RmEpsilonFst<Arc> &fst)
520       : CacheStateIterator<RmEpsilonFst<Arc>>(fst, fst.GetMutableImpl()) {}
521 };
522
523 // Specialization for RmEpsilonFst.
524 template <class Arc>
525 class ArcIterator<RmEpsilonFst<Arc>>
526     : public CacheArcIterator<RmEpsilonFst<Arc>> {
527  public:
528   using StateId = typename Arc::StateId;
529
530   ArcIterator(const RmEpsilonFst<Arc> &fst, StateId s)
531       : CacheArcIterator<RmEpsilonFst<Arc>>(fst.GetMutableImpl(), s) {
532     if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
533   }
534 };
535
536 template <class Arc>
537 inline void RmEpsilonFst<Arc>::InitStateIterator(
538     StateIteratorData<Arc> *data) const {
539   data->base = new StateIterator<RmEpsilonFst<Arc>>(*this);
540 }
541
542 // Useful alias when using StdArc.
543 using StdRmEpsilonFst = RmEpsilonFst<StdArc>;
544
545 }  // namespace fst
546
547 #endif  // FST_LIB_RMEPSILON_H_