1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions and classes that implement epsilon-removal.
6 #ifndef FST_LIB_RMEPSILON_H_
7 #define FST_LIB_RMEPSILON_H_
9 #include <forward_list>
12 #include <unordered_map>
18 #include <fst/arcfilter.h>
19 #include <fst/cache.h>
20 #include <fst/connect.h>
21 #include <fst/factor-weight.h>
22 #include <fst/invert.h>
23 #include <fst/prune.h>
24 #include <fst/queue.h>
25 #include <fst/shortest-distance.h>
26 #include <fst/topsort.h>
// Options for epsilon removal. Extends ShortestDistanceOptions (specialized
// with an EpsilonArcFilter) with an output-connection flag and two pruning
// thresholds.
31 template <class Arc, class Queue>
32 class RmEpsilonOptions
33 : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>> {
35 using StateId = typename Arc::StateId;
36 using Weight = typename Arc::Weight;
38 bool connect; // Connect output
39 Weight weight_threshold; // Pruning weight threshold.
40 StateId state_threshold; // Pruning state threshold.
// Forwards queue, a default-constructed EpsilonArcFilter, kNoStateId and
// delta to the ShortestDistanceOptions base. The thresholds default to
// Weight::Zero() and kNoStateId, the values RmEpsilon() treats as
// "no pruning requested".
// NOTE(review): the initializer list sets weight_threshold and
// state_threshold only; a `connect` parameter/initializer is not visible in
// this excerpt -- confirm against the full file.
42 explicit RmEpsilonOptions(Queue *queue, float delta = kDelta,
44 Weight weight_threshold = Weight::Zero(),
45 StateId state_threshold = kNoStateId)
46 : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>>(
47 queue, EpsilonArcFilter<Arc>(), kNoStateId, delta),
49 weight_threshold(std::move(weight_threshold)),
50 state_threshold(state_threshold) {}
55 // Computation state of the epsilon-removal algorithm.
// Holds the per-state scratch data reused across calls to Expand(): the
// shortest-distance machinery, the traversal stack and visited flags, and
// the arcs and final weight produced for the state most recently expanded.
// Instances are non-copyable.
56 template <class Arc, class Queue>
57 class RmEpsilonState {
59 using Label = typename Arc::Label;
60 using StateId = typename Arc::StateId;
61 using Weight = typename Arc::Weight;
// Builds the state for the given FST. The shortest-distance state is
// constructed from fst_, the caller-supplied distance vector and opts.
// NOTE(review): the meaning of the trailing `true` passed to sd_state_ is
// not visible here -- confirm against ShortestDistanceState.
63 RmEpsilonState(const Fst<Arc> &fst, std::vector<Weight> *distance,
64 const RmEpsilonOptions<Arc, Queue> &opts)
67 sd_state_(fst_, distance, opts, true),
// Computes the arcs and final weight for state s (see the out-of-class
// definition).
70 void Expand(StateId s);
// Mutable access to the arcs of the state being expanded; callers in this
// file drain this vector after each Expand().
72 std::vector<Arc> &Arcs() { return arcs_; }
// Final weight of the state being expanded.
74 const Weight &Final() const { return final_; }
// Reports whether the underlying shortest-distance computation errored.
76 bool Error() const { return sd_state_.Error(); }
// Key identifying an output arc by (input label, output label, destination).
86 Element(Label ilabel, Label olabel, StateId nexstate)
87 : ilabel(ilabel), olabel(olabel), nextstate(nexstate) {}
// Hash for Element: nextstate + ilabel * 7853 + olabel * 7867.
92 size_t operator()(const Element &element) const {
93 static constexpr auto prime0 = 7853;
94 static constexpr auto prime1 = 7867;
95 return element.nextstate + element.ilabel * prime0 +
96 element.olabel * prime1;
// Equality for Element: all three fields must match.
102 bool operator()(const Element &e1, const Element &e2) const {
103 return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) &&
104 (e1.nextstate == e2.nextstate);
// Maps an Element to (expand id, index into arcs_); see element_map_ below.
108 using ElementMap = std::unordered_map<Element, std::pair<StateId, size_t>,
109 ElementHash, ElementEqual>;
111 const Fst<Arc> &fst_;
112 // Distance from state being expanded in epsilon-closure.
113 std::vector<Weight> *distance_;
114 // Shortest distance algorithm computation state.
115 internal::ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc>> sd_state_;
116 // Maps an element to a pair corresponding to a position in the arcs vector
117 // of the state being expanded. The element corresponds to the position in
118 // the arcs_ vector if p.first is equal to the state being expanded.
119 ElementMap element_map_;
120 EpsilonArcFilter<Arc> eps_filter_;
121 std::stack<StateId> eps_queue_; // Queue used to visit the epsilon-closure.
122 std::vector<bool> visited_; // True if the state has been visited.
123 std::forward_list<StateId> visited_states_; // List of visited states.
124 std::vector<Arc> arcs_; // Arcs of state being expanded
125 Weight final_; // Final weight of state being expanded.
126 StateId expand_id_; // Unique ID for each call to Expand
// Copying is disabled.
128 RmEpsilonState(const RmEpsilonState &) = delete;
129 RmEpsilonState &operator=(const RmEpsilonState &) = delete;
// Computes the epsilon-free arcs and the final weight of `source`.
//
// Runs the shortest-distance computation from `source`, then walks the states
// reachable through arcs accepted by eps_filter_ using eps_queue_ as a stack.
// Every other arc met on the way is added to arcs_ with its weight
// pre-multiplied by the distance to the state it leaves; arcs with identical
// (ilabel, olabel, nextstate) are merged with Plus. final_ accumulates the
// distance-weighted final weights of all visited states. Returns early, leaving
// final_ at Weight::Zero(), if the shortest-distance step reports an error.
// NOTE(review): this excerpt omits some lines of the body (e.g. the stack pop
// and the else branches); the comments below describe only what is visible.
132 template <class Arc, class Queue>
133 void RmEpsilonState<Arc, Queue>::Expand(typename Arc::StateId source) {
134 final_ = Weight::Zero();
136 sd_state_.ShortestDistance(source);
137 if (sd_state_.Error()) return;
138 eps_queue_.push(source);
139 while (!eps_queue_.empty()) {
140 const auto state = eps_queue_.top();
// Grow the visited bitmap on demand, then skip states already handled in
// this expansion.
142 while (visited_.size() <= state) visited_.push_back(false);
143 if (visited_[state]) continue;
144 visited_[state] = true;
// Remember the state so its visited flag can be cleared at the end.
145 visited_states_.push_front(state);
146 for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
// Work on a copy so the weight can be pre-multiplied by the distance to
// `state`.
148 auto arc = aiter.Value();
149 arc.weight = Times((*distance_)[state], arc.weight);
// Arcs accepted by the epsilon filter are followed rather than emitted.
150 if (eps_filter_(arc)) {
151 while (visited_.size() <= arc.nextstate) visited_.push_back(false);
152 if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate);
// Look up (or insert) the arc's key; a fresh entry points at the slot the
// arc is about to occupy in arcs_.
154 const Element element(arc.ilabel, arc.olabel, arc.nextstate);
155 auto insert_result = element_map_.insert(
156 std::make_pair(element, std::make_pair(expand_id_, arcs_.size())));
157 if (insert_result.second) {
158 arcs_.push_back(arc);
// Key already present and recorded during this expansion: merge weights.
160 if (insert_result.first->second.first == expand_id_) {
161 auto &weight = arcs_[insert_result.first->second.second].weight;
162 weight = Plus(weight, arc.weight);
// Key present but recorded under a different expand id: re-point the entry
// at a new slot for this expansion.
164 insert_result.first->second.first = expand_id_;
165 insert_result.first->second.second = arcs_.size();
166 arcs_.push_back(arc);
// Fold this state's final weight, scaled by its distance, into final_.
171 final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
// Reset only the flags that were set, so visited_ is clean for the next call.
173 while (!visited_states_.empty()) {
174 visited_[visited_states_.front()] = false;
175 visited_states_.pop_front();
180 } // namespace internal
182 // Removes epsilon-transitions (when both the input and output label are an
183 // epsilon) from a transducer. The result will be an equivalent FST that has no
184 // such epsilon transitions. This version modifies its input. It allows fine
185 // control via the options argument; see below for a simpler interface.
187 // The distance vector will be used to hold the shortest distances during the
188 // epsilon-closure computation. The state queue discipline and convergence delta
189 // are taken in the options argument.
// In-place epsilon removal with explicit options.
//
// fst:      FST to modify; nothing is done when it has no start state.
// distance: scratch vector handed to the internal RmEpsilonState.
// opts:     queue/delta settings plus the connect flag and pruning thresholds.
//
// Steps visible below: (1) mark states that are the start state or have a
// non-epsilon incoming arc, (2) compute an order in which to process states,
// (3) expand each processed state and replace its final weight and arcs,
// (4) optionally delete arcs of unmarked states, (5) propagate errors,
// (6) prune if a threshold is set.
// NOTE(review): several lines of the body are absent from this excerpt; the
// hedged comments below mark where behavior had to be inferred.
190 template <class Arc, class Queue>
191 void RmEpsilon(MutableFst<Arc> *fst,
192 std::vector<typename Arc::Weight> *distance,
193 const RmEpsilonOptions<Arc, Queue> &opts) {
194 using Label = typename Arc::Label;
195 using StateId = typename Arc::StateId;
196 using Weight = typename Arc::Weight;
// An FST without a start state is left untouched.
197 if (fst->Start() == kNoStateId) return;
198 // noneps_in[s] will be set to true iff s admits a non-epsilon incoming
199 // transition or is the start state.
200 std::vector<bool> noneps_in(fst->NumStates(), false);
201 noneps_in[fst->Start()] = true;
202 for (size_t i = 0; i < fst->NumStates(); ++i) {
203 for (ArcIterator<Fst<Arc>> aiter(*fst, i); !aiter.Done(); aiter.Next()) {
204 const auto &arc = aiter.Value();
// An arc is non-epsilon when either label is non-zero.
205 if (arc.ilabel != 0 || arc.olabel != 0) {
206 noneps_in[arc.nextstate] = true;
210 // States sorted in topological order when (acyclic) or generic topological
212 std::vector<StateId> states;
213 states.reserve(fst->NumStates());
// Case 1: the FST is already known to be top-sorted; use state ids in order.
214 if (fst->Properties(kTopSorted, false) & kTopSorted) {
215 for (size_t i = 0; i < fst->NumStates(); i++) states.push_back(i);
// Case 2: known acyclic; derive an order from a DFS restricted to epsilon
// arcs.
216 } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
217 std::vector<StateId> order;
219 TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
220 DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
221 // Sanity check: should be acyclic if property bit is set.
223 FSTERROR() << "RmEpsilon: Inconsistent acyclic property bit";
224 fst->SetProperties(kError, kError);
// Invert the mapping: position order[i] of `states` receives state i.
227 states.resize(order.size());
228 for (StateId i = 0; i < order.size(); i++) states[order[i]] = i;
// Case 3 (presumably the general, possibly cyclic case): compute SCCs over
// epsilon arcs. `first` and `next` appear to form per-SCC linked lists of
// states that are then walked SCC by SCC -- confirm against the full file.
231 std::vector<StateId> scc;
232 SccVisitor<Arc> scc_visitor(&scc, nullptr, nullptr, &props);
233 DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
234 std::vector<StateId> first(scc.size(), kNoStateId);
235 std::vector<StateId> next(scc.size(), kNoStateId);
236 for (StateId i = 0; i < scc.size(); i++) {
237 if (first[scc[i]] != kNoStateId) next[i] = first[scc[i]];
240 for (StateId i = 0; i < first.size(); i++) {
241 for (auto j = first[i]; j != kNoStateId; j = next[j]) {
// Process states from the back of `states`.
246 internal::RmEpsilonState<Arc, Queue> rmeps_state(*fst, distance, opts);
247 while (!states.empty()) {
248 const auto state = states.back();
// States with no non-epsilon incoming arc get special handling when
// connecting or pruning is requested (presumably they are skipped; their
// arcs are deleted further down).
250 if (!noneps_in[state] &&
251 (opts.connect || opts.weight_threshold != Weight::Zero() ||
252 opts.state_threshold != kNoStateId)) {
// Replace the state's final weight and arcs with the expanded ones.
255 rmeps_state.Expand(state);
256 fst->SetFinal(state, rmeps_state.Final());
257 fst->DeleteArcs(state);
258 auto &arcs = rmeps_state.Arcs();
259 fst->ReserveArcs(state, arcs.size());
// Arcs are added starting from the back of the expansion result.
260 while (!arcs.empty()) {
261 fst->AddArc(state, arcs.back());
// When connecting or pruning, drop the arcs of every unmarked state.
265 if (opts.connect || opts.weight_threshold != Weight::Zero() ||
266 opts.state_threshold != kNoStateId) {
267 for (size_t s = 0; s < fst->NumStates(); ++s) {
268 if (!noneps_in[s]) fst->DeleteArcs(s);
// Surface shortest-distance errors on the output FST.
271 if (rmeps_state.Error()) fst->SetProperties(kError, kError);
273 RmEpsilonProperties(fst->Properties(kFstProperties, false)),
// Prune only when at least one threshold differs from its "off" value.
275 if (opts.weight_threshold != Weight::Zero() ||
276 opts.state_threshold != kNoStateId) {
277 Prune(fst, opts.weight_threshold, opts.state_threshold);
// Connect-only case: requested, and no pruning threshold set.
279 if (opts.connect && opts.weight_threshold == Weight::Zero() &&
280 opts.state_threshold == kNoStateId) {
285 // Removes epsilon-transitions (when both the input and output label
286 // are an epsilon) from a transducer. The result will be an equivalent
287 // FST that has no such epsilon transitions. This version modifies its
288 // input. It has a simplified interface; see above for a version that
289 // allows finer control.
295 // Unweighted: O(v^2 + ve).
296 // Acyclic: O(v^2 + V e).
297 // Tropical semiring: O(v^2 log V + ve).
298 // General: exponential.
302 // where v is the number of states visited and e is the number of arcs visited.
304 // For more information, see:
306 // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
307 // algorithms for weighted transducers. International Journal of Computer
308 // Science 13(1): 129-143.
// Convenience overload: builds a local distance vector and an AutoQueue over
// the FST's epsilon arcs, packs the arguments into RmEpsilonOptions, and
// forwards to the options-based RmEpsilon() overload.
//
// fst:              FST to modify in place.
// connect:          passed through as the options' connect flag.
// weight_threshold: pruning weight threshold (default Weight::Zero()).
// state_threshold:  pruning state threshold (default kNoStateId).
// delta:            convergence delta passed to the options.
310 void RmEpsilon(MutableFst<Arc> *fst, bool connect = true,
311 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
312 typename Arc::StateId state_threshold = kNoStateId,
313 float delta = kDelta) {
314 using StateId = typename Arc::StateId;
315 using Weight = typename Arc::Weight;
// The queue is constructed against the same distance vector that is handed
// to RmEpsilon() below.
316 std::vector<Weight> distance;
317 AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
318 RmEpsilonOptions<Arc, AutoQueue<StateId>> opts(
319 &state_queue, delta, connect, weight_threshold, state_threshold);
320 RmEpsilon(fst, &distance, opts);
// Options for the delayed RmEpsilonFst: the cache options plus a delta value.
323 struct RmEpsilonFstOptions : CacheOptions {
// Starts from existing cache options and sets delta.
326 explicit RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
327 : CacheOptions(opts), delta(delta) {}
// Default-constructs the cache options and sets delta.
329 explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
334 // Implementation of delayed RmEpsilonFst.
// Cache-backed implementation behind RmEpsilonFst. States are expanded on
// demand through rmeps_state_, and the resulting final weights and arcs are
// stored via the CacheImpl base.
336 class RmEpsilonFstImpl : public CacheImpl<Arc> {
338 using StateId = typename Arc::StateId;
339 using Weight = typename Arc::Weight;
341 using Store = DefaultCacheStore<Arc>;
342 using State = typename Store::State;
344 using FstImpl<Arc>::Properties;
345 using FstImpl<Arc>::SetType;
346 using FstImpl<Arc>::SetProperties;
347 using FstImpl<Arc>::SetInputSymbols;
348 using FstImpl<Arc>::SetOutputSymbols;
350 using CacheBaseImpl<CacheState<Arc>>::HasArcs;
351 using CacheBaseImpl<CacheState<Arc>>::HasFinal;
352 using CacheBaseImpl<CacheState<Arc>>::HasStart;
353 using CacheBaseImpl<CacheState<Arc>>::PushArc;
354 using CacheBaseImpl<CacheState<Arc>>::SetArcs;
355 using CacheBaseImpl<CacheState<Arc>>::SetFinal;
356 using CacheBaseImpl<CacheState<Arc>>::SetStart;
// Builds the implementation over `fst`. The epsilon-removal state uses a FIFO
// queue and delta_, with the options' connect flag set to false. The FST type
// is "rmepsilon" and the symbol tables are taken from the input.
358 RmEpsilonFstImpl(const Fst<Arc> &fst, const RmEpsilonFstOptions &opts)
359 : CacheImpl<Arc>(opts),
364 RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false)) {
365 SetType("rmepsilon");
367 RmEpsilonProperties(fst.Properties(kFstProperties, false), true),
369 SetInputSymbols(fst.InputSymbols());
370 SetOutputSymbols(fst.OutputSymbols());
// Copy constructor: copies the wrapped FST with Copy(true), builds a fresh
// epsilon-removal state bound to this object's own queue_, and copies
// properties and symbol tables from `impl`.
373 RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
374 : CacheImpl<Arc>(impl),
375 fst_(impl.fst_->Copy(true)),
379 RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false)) {
380 SetType("rmepsilon");
381 SetProperties(impl.Properties(), kCopyProperties);
382 SetInputSymbols(impl.InputSymbols());
383 SetOutputSymbols(impl.OutputSymbols());
// The start state is that of the wrapped FST, cached on first request.
387 if (!HasStart()) SetStart(fst_->Start());
388 return CacheImpl<Arc>::Start();
// Final weight of s; expands s first when it is not cached.
391 Weight Final(StateId s) {
392 if (!HasFinal(s)) Expand(s);
393 return CacheImpl<Arc>::Final(s);
// Number of arcs leaving s; expands s first when its arcs are not cached.
396 size_t NumArcs(StateId s) {
397 if (!HasArcs(s)) Expand(s);
398 return CacheImpl<Arc>::NumArcs(s);
// Number of input-epsilon arcs leaving s; expands s on demand.
401 size_t NumInputEpsilons(StateId s) {
402 if (!HasArcs(s)) Expand(s);
403 return CacheImpl<Arc>::NumInputEpsilons(s);
// Number of output-epsilon arcs leaving s; expands s on demand.
406 size_t NumOutputEpsilons(StateId s) {
407 if (!HasArcs(s)) Expand(s);
408 return CacheImpl<Arc>::NumOutputEpsilons(s);
411 uint64 Properties() const override { return Properties(kFstProperties); }
413 // Sets error if found and returns other FST impl properties.
// The error check runs only when the mask asks for kError; it consults both
// the wrapped FST and the epsilon-removal state.
414 uint64 Properties(uint64 mask) const override {
415 if ((mask & kError) &&
416 (fst_->Properties(kError, false) || rmeps_state_.Error())) {
417 SetProperties(kError, kError);
419 return FstImpl<Arc>::Properties(mask);
// Initializes an arc iterator over s, expanding s first if needed.
422 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
423 if (!HasArcs(s)) Expand(s);
424 CacheImpl<Arc>::InitArcIterator(s, data);
// Computes s via the epsilon-removal state, then stores the final weight and
// pushes the produced arcs into the cache, taking them from the back of the
// arcs vector.
427 void Expand(StateId s) {
428 rmeps_state_.Expand(s);
429 SetFinal(s, rmeps_state_.Final());
430 auto &arcs = rmeps_state_.Arcs();
431 while (!arcs.empty()) {
432 PushArc(s, arcs.back());
// Wrapped input FST (owned).
439 std::unique_ptr<const Fst<Arc>> fst_;
// Scratch distance vector for the epsilon-removal state.
// NOTE(review): the initializer binding it to rmeps_state_ is not visible in
// this excerpt -- confirm.
441 std::vector<Weight> distance_;
// Queue handed (by pointer) to the epsilon-removal options.
442 FifoQueue<StateId> queue_;
// Per-state epsilon-removal machinery.
443 internal::RmEpsilonState<Arc, FifoQueue<StateId>> rmeps_state_;
446 } // namespace internal
448 // Removes epsilon-transitions (when both the input and output label are an
449 // epsilon) from a transducer. The result will be an equivalent FST that has no
450 // such epsilon transitions. This version is a
456 // Unweighted: O(v^2 + ve).
457 // General: exponential.
461 // where v is the number of states visited and e is the number of arcs visited.
462 // Constant time to visit an input state or arc is assumed and exclusive of
465 // For more information, see:
467 // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
468 // algorithms for weighted transducers. International Journal of Computer
469 // Science 13(1): 129-143.
471 // This class attaches interface to implementation and handles
472 // reference counting, delegating most methods to ImplToFst.
// Public FST class for epsilon removal. It holds an
// internal::RmEpsilonFstImpl through ImplToFst and forwards to it.
474 class RmEpsilonFst : public ImplToFst<internal::RmEpsilonFstImpl<A>> {
477 using StateId = typename Arc::StateId;
479 using Store = DefaultCacheStore<Arc>;
480 using State = typename Store::State;
481 using Impl = internal::RmEpsilonFstImpl<Arc>;
// The iterator specializations need access to the implementation.
483 friend class ArcIterator<RmEpsilonFst<Arc>>;
484 friend class StateIterator<RmEpsilonFst<Arc>>;
// Wraps `fst` using default RmEpsilonFstOptions.
486 explicit RmEpsilonFst(const Fst<Arc> &fst)
487 : ImplToFst<Impl>(std::make_shared<Impl>(fst, RmEpsilonFstOptions())) {}
// Wraps `fst` using the supplied options.
489 RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
490 : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
492 // See Fst<>::Copy() for doc.
493 RmEpsilonFst(const RmEpsilonFst<Arc> &fst, bool safe = false)
494 : ImplToFst<Impl>(fst, safe) {}
496 // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
// The copy is allocated with `new`.
497 RmEpsilonFst<Arc> *Copy(bool safe = false) const override {
498 return new RmEpsilonFst<Arc>(*this, safe);
// Defined out of class below.
501 inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
// Delegates to the implementation, which may expand s.
503 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
504 GetMutableImpl()->InitArcIterator(s, data);
508 using ImplToFst<Impl>::GetImpl;
509 using ImplToFst<Impl>::GetMutableImpl;
// Assignment is disabled.
511 RmEpsilonFst &operator=(const RmEpsilonFst &) = delete;
514 // Specialization for RmEpsilonFst.
// State iterator for RmEpsilonFst, built on CacheStateIterator.
516 class StateIterator<RmEpsilonFst<Arc>>
517 : public CacheStateIterator<RmEpsilonFst<Arc>> {
// Passes the FST and its mutable implementation to the cache iterator.
519 explicit StateIterator(const RmEpsilonFst<Arc> &fst)
520 : CacheStateIterator<RmEpsilonFst<Arc>>(fst, fst.GetMutableImpl()) {}
523 // Specialization for RmEpsilonFst.
// Arc iterator for RmEpsilonFst, built on CacheArcIterator.
525 class ArcIterator<RmEpsilonFst<Arc>>
526 : public CacheArcIterator<RmEpsilonFst<Arc>> {
528 using StateId = typename Arc::StateId;
// Constructs the cache iterator for state s, then expands s if the
// implementation does not yet have its arcs.
530 ArcIterator(const RmEpsilonFst<Arc> &fst, StateId s)
531 : CacheArcIterator<RmEpsilonFst<Arc>>(fst.GetMutableImpl(), s) {
532 if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
// Installs a newly allocated StateIterator over this FST as data->base.
537 inline void RmEpsilonFst<Arc>::InitStateIterator(
538 StateIteratorData<Arc> *data) const {
539 data->base = new StateIterator<RmEpsilonFst<Arc>>(*this);
542 // Useful alias when using StdArc.
543 using StdRmEpsilonFst = RmEpsilonFst<StdArc>;
547 #endif // FST_LIB_RMEPSILON_H_