1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions and classes that implemement epsilon-removal.
6 #ifndef FST_LIB_RMEPSILON_H_
7 #define FST_LIB_RMEPSILON_H_
9 #include <forward_list>
12 #include <unordered_map>
18 #include <fst/arcfilter.h>
19 #include <fst/cache.h>
20 #include <fst/connect.h>
21 #include <fst/factor-weight.h>
22 #include <fst/invert.h>
23 #include <fst/prune.h>
24 #include <fst/queue.h>
25 #include <fst/shortest-distance.h>
26 #include <fst/topsort.h>
// Options for the in-place, options-based RmEpsilon(). The shortest-distance
// settings (queue, arc filter, convergence delta) are inherited from
// ShortestDistanceOptions, instantiated here with EpsilonArcFilter<Arc>.
31 template <class Arc, class Queue>
32 class RmEpsilonOptions
33 : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>> {
35 using StateId = typename Arc::StateId;
36 using Weight = typename Arc::Weight;
// Read by RmEpsilon(): together with the pruning thresholds it enables
// dropping the arcs of states with no non-epsilon incoming arc, and, when no
// pruning threshold is set, it gates the final connection step.
38 bool connect; // Connect output
// RmEpsilon() calls Prune() when either threshold differs from its default
// (Weight::Zero() / kNoStateId).
39 Weight weight_threshold; // Pruning weight threshold.
40 StateId state_threshold; // Pruning state threshold.
// The base options get a default-constructed EpsilonArcFilter and kNoStateId
// as their third argument (presumably the source state; TODO confirm).
// NOTE(review): the `connect` parameter and its initializer are not visible in
// this excerpt; the simplified RmEpsilon() passes arguments in the order
// (queue, delta, connect, weight_threshold, state_threshold).
42 explicit RmEpsilonOptions(Queue *queue, float delta = kDelta,
44 Weight weight_threshold = Weight::Zero(),
45 StateId state_threshold = kNoStateId)
46 : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>>(
47 queue, EpsilonArcFilter<Arc>(), kNoStateId, delta),
49 weight_threshold(std::move(weight_threshold)),
50 state_threshold(state_threshold) {}
55 // Computation state of the epsilon-removal algorithm.
// Expand(s) computes the arcs and final weight that replace state s after its
// epsilon-closure is folded in; the results are read back through Arcs() and
// Final(). A single instance is reused across many Expand() calls.
56 template <class Arc, class Queue>
57 class RmEpsilonState {
59 using Label = typename Arc::Label;
60 using StateId = typename Arc::StateId;
61 using Weight = typename Arc::Weight;
// `distance` is shared with the embedded ShortestDistanceState and holds the
// epsilon-closure distances from the state currently being expanded.
63 RmEpsilonState(const Fst<Arc> &fst, std::vector<Weight> *distance,
64 const RmEpsilonOptions<Arc, Queue> &opts)
// NOTE(review): the remaining member initializers are not visible in this
// excerpt.
67 sd_state_(fst_, distance, opts, true),
// Expands state s; afterwards Arcs() and Final() describe its replacement.
70 void Expand(StateId s);
// Arcs of the last expanded state, returned by non-const reference; callers
// drain it by looping while it is non-empty.
72 std::vector<Arc> &Arcs() { return arcs_; }
// Final weight of the last expanded state.
74 const Weight &Final() const { return final_; }
// True if the shortest-distance computation reported an error.
76 bool Error() const { return sd_state_.Error(); }
// Element: key (ilabel, olabel, nextstate) used to merge identical arcs that
// are reached from different states of the same epsilon-closure.
// NOTE(review): the constructor parameter is spelled `nexstate` (sic).
86 Element(Label ilabel, Label olabel, StateId nexstate)
87 : ilabel(ilabel), olabel(olabel), nextstate(nexstate) {}
// Hash: nextstate plus the two labels scaled by distinct prime multipliers.
92 size_t operator()(const Element &element) const {
93 static constexpr size_t prime0 = 7853;
94 static constexpr size_t prime1 = 7867;
95 return static_cast<size_t>(element.nextstate) +
96 static_cast<size_t>(element.ilabel) * prime0 +
97 static_cast<size_t>(element.olabel) * prime1;
// Equality: all three fields must match.
103 bool operator()(const Element &e1, const Element &e2) const {
104 return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) &&
105 (e1.nextstate == e2.nextstate);
// Element -> (expand ID, index into arcs_).
109 using ElementMap = std::unordered_map<Element, std::pair<StateId, size_t>,
110 ElementHash, ElementEqual>;
112 const Fst<Arc> &fst_;
113 // Distance from state being expanded in epsilon-closure.
114 std::vector<Weight> *distance_;
115 // Shortest distance algorithm computation state.
116 internal::ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc>> sd_state_;
117 // Maps an element to a pair p = (expand ID, position in the arcs_ vector of
118 // the state being expanded). The element corresponds to position p.second in
119 // arcs_ only if p.first is equal to expand_id_ (set by the current Expand()).
120 ElementMap element_map_;
121 EpsilonArcFilter<Arc> eps_filter_;
122 std::stack<StateId> eps_queue_; // Stack (LIFO) used to visit the epsilon-closure.
123 std::vector<bool> visited_; // True if the state has been visited.
// Visited states of the current Expand(); used to reset visited_ afterwards.
124 std::forward_list<StateId> visited_states_; // List of visited states.
125 std::vector<Arc> arcs_; // Arcs of state being expanded.
126 Weight final_; // Final weight of state being expanded.
// NOTE(review): the code that advances expand_id_ is not visible in this
// excerpt; stale element_map_ entries are recognized by comparing against it.
127 StateId expand_id_; // Unique ID for each call to Expand
129 RmEpsilonState(const RmEpsilonState &) = delete;
130 RmEpsilonState &operator=(const RmEpsilonState &) = delete;
// Computes the epsilon-closure of `source` and collects into arcs_ / final_
// the arcs and final weight that replace it:
//  - runs the epsilon shortest-distance computation from `source` first and
//    returns early (with final_ reset to Zero) if it reports an error;
//  - walks the closure depth-first using eps_queue_ (a stack);
//  - every arc leaving a visited state has its weight left-multiplied by that
//    state's closure distance;
//  - arcs accepted by eps_filter_ extend the traversal; the remaining arcs
//    (presumably the non-epsilon case; the `else` is not visible here) are
//    merged into arcs_ by (ilabel, olabel, nextstate), Plus-ing the weights
//    of duplicates found within this same call;
//  - final_ is the Plus over visited states of distance (x) Final(state).
// NOTE(review): several lines of this definition (the stack pop, the
// else-branches and the closing statements) are not visible in this excerpt.
133 template <class Arc, class Queue>
134 void RmEpsilonState<Arc, Queue>::Expand(typename Arc::StateId source) {
135 final_ = Weight::Zero();
137 sd_state_.ShortestDistance(source);
138 if (sd_state_.Error()) return;
139 eps_queue_.push(source);
140 while (!eps_queue_.empty()) {
141 const auto state = eps_queue_.top();
// Grow visited_ lazily so any state ID can be indexed.
// NOTE(review): compares size_t with the signed StateId; consider a cast.
143 while (visited_.size() <= state) visited_.push_back(false);
144 if (visited_[state]) continue;
145 visited_[state] = true;
146 visited_states_.push_front(state);
147 for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
149 auto arc = aiter.Value();
150 arc.weight = Times((*distance_)[state], arc.weight);
151 if (eps_filter_(arc)) {
152 while (visited_.size() <= arc.nextstate) visited_.push_back(false);
153 if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate);
// Merge this arc with any identical arc already collected in this call.
155 const Element element(arc.ilabel, arc.olabel, arc.nextstate);
156 auto insert_result = element_map_.insert(
157 std::make_pair(element, std::make_pair(expand_id_, arcs_.size())));
158 if (insert_result.second) {
159 arcs_.push_back(arc);
// Entry from this Expand() call: sum into the existing arc's weight.
161 if (insert_result.first->second.first == expand_id_) {
162 auto &weight = arcs_[insert_result.first->second.second].weight;
163 weight = Plus(weight, arc.weight);
// Stale entry from an earlier call: re-point it at a new slot in arcs_.
165 insert_result.first->second.first = expand_id_;
166 insert_result.first->second.second = arcs_.size();
167 arcs_.push_back(arc);
// Accumulate the closure-weighted final weight of each visited state.
172 final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
// Reset only the visited flags touched by this call.
174 while (!visited_states_.empty()) {
175 visited_[visited_states_.front()] = false;
176 visited_states_.pop_front();
181 } // namespace internal
183 // Removes epsilon-transitions (when both the input and output label are an
184 // epsilon) from a transducer. The result will be an equivalent FST that has no
185 // such epsilon transitions. This version modifies its input. It allows fine
186 // control via the options argument; see below for a simpler interface.
188 // The distance vector will be used to hold the shortest distances during the
189 // epsilon-closure computation. The state queue discipline and convergence delta
190 // are taken in the options argument.
// Outline of the visible code:
//  1. Mark states that have a non-epsilon incoming arc or are the start
//     state (noneps_in).
//  2. Build a processing order `states`: identity if kTopSorted; otherwise
//     from a DFS over epsilon arcs, topological when kAcyclic, SCC-based
//     otherwise.
//  3. Take states from the back of `states` and replace each one's final
//     weight and arcs with those computed by RmEpsilonState::Expand().
//  4. Optionally drop arcs of states with no non-epsilon incoming arc, record
//     errors/properties, then prune or connect as requested by `opts`.
// NOTE(review): numerous lines of this definition (closing braces, some
// statement bodies) are not visible in this excerpt.
191 template <class Arc, class Queue>
192 void RmEpsilon(MutableFst<Arc> *fst,
193 std::vector<typename Arc::Weight> *distance,
194 const RmEpsilonOptions<Arc, Queue> &opts) {
195 using Label = typename Arc::Label;
196 using StateId = typename Arc::StateId;
197 using Weight = typename Arc::Weight;
// An FST without a start state is left untouched.
198 if (fst->Start() == kNoStateId) return;
199 // noneps_in[s] will be set to true iff s admits a non-epsilon incoming
200 // transition or is the start state.
201 std::vector<bool> noneps_in(fst->NumStates(), false);
202 noneps_in[fst->Start()] = true;
203 for (size_t i = 0; i < fst->NumStates(); ++i) {
204 for (ArcIterator<Fst<Arc>> aiter(*fst, i); !aiter.Done(); aiter.Next()) {
205 const auto &arc = aiter.Value();
206 if (arc.ilabel != 0 || arc.olabel != 0) {
207 noneps_in[arc.nextstate] = true;
211 // States sorted in topological order (when acyclic) or in a generic
// topological order (SCC by SCC, over epsilon arcs) otherwise.
213 std::vector<StateId> states;
214 states.reserve(fst->NumStates());
// Already top-sorted: the state IDs themselves give the order.
215 if (fst->Properties(kTopSorted, false) & kTopSorted) {
216 for (size_t i = 0; i < fst->NumStates(); i++) states.push_back(i);
// Acyclic: topologically sort along epsilon arcs only (EpsilonArcFilter).
217 } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
218 std::vector<StateId> order;
220 TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
221 DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
222 // Sanity check: should be acyclic if property bit is set.
224 FSTERROR() << "RmEpsilon: Inconsistent acyclic property bit";
225 fst->SetProperties(kError, kError);
// Presumably order[s] is the topological position of state s, so this inverts
// it into a state list (TODO confirm against TopOrderVisitor).
// NOTE(review): `StateId i < order.size()` mixes signed and unsigned types.
228 states.resize(order.size());
229 for (StateId i = 0; i < order.size(); i++) states[order[i]] = i;
// General case: SCCs of the epsilon subgraph; first/next form per-SCC linked
// lists of states (some list-building lines are not visible here).
232 std::vector<StateId> scc;
233 SccVisitor<Arc> scc_visitor(&scc, nullptr, nullptr, &props);
234 DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
235 std::vector<StateId> first(scc.size(), kNoStateId);
236 std::vector<StateId> next(scc.size(), kNoStateId);
237 for (StateId i = 0; i < scc.size(); i++) {
238 if (first[scc[i]] != kNoStateId) next[i] = first[scc[i]];
241 for (StateId i = 0; i < first.size(); i++) {
242 for (auto j = first[i]; j != kNoStateId; j = next[j]) {
// One expansion state reused for every state; it shares `distance`.
247 internal::RmEpsilonState<Arc, Queue> rmeps_state(*fst, distance, opts);
// States are taken from the back of `states`.
248 while (!states.empty()) {
249 const auto state = states.back();
// With connect or pruning enabled, a non-start state with no non-epsilon
// incoming arc receives no arc after epsilon removal (only non-epsilon arcs
// are re-added). The branch body is not visible here; such states have their
// arcs deleted after the loop.
251 if (!noneps_in[state] &&
252 (opts.connect || opts.weight_threshold != Weight::Zero() ||
253 opts.state_threshold != kNoStateId)) {
// Replace the state's final weight and arcs with those of its epsilon-closure.
256 rmeps_state.Expand(state);
257 fst->SetFinal(state, rmeps_state.Final());
258 fst->DeleteArcs(state);
259 auto &arcs = rmeps_state.Arcs();
260 fst->ReserveArcs(state, arcs.size());
261 while (!arcs.empty()) {
262 fst->AddArc(state, arcs.back());
// Drop the outgoing arcs of states that no longer have any incoming arc.
266 if (opts.connect || opts.weight_threshold != Weight::Zero() ||
267 opts.state_threshold != kNoStateId) {
268 for (size_t s = 0; s < fst->NumStates(); ++s) {
269 if (!noneps_in[s]) fst->DeleteArcs(s);
// Propagate errors from the shortest-distance computation.
272 if (rmeps_state.Error()) fst->SetProperties(kError, kError);
274 RmEpsilonProperties(fst->Properties(kFstProperties, false)),
// Prune when either threshold is set; otherwise connect if requested (the
// connect call itself is not visible in this excerpt).
276 if (opts.weight_threshold != Weight::Zero() ||
277 opts.state_threshold != kNoStateId) {
278 Prune(fst, opts.weight_threshold, opts.state_threshold);
280 if (opts.connect && opts.weight_threshold == Weight::Zero() &&
281 opts.state_threshold == kNoStateId) {
286 // Removes epsilon-transitions (when both the input and output label
287 // are an epsilon) from a transducer. The result will be an equivalent
288 // FST that has no such epsilon transitions. This version modifies its
289 // input. It has a simplified interface; see above for a version that
290 // allows finer control.
296 // Unweighted: O(v^2 + ve).
297 // Acyclic: O(v^2 + V e).
298 // Tropical semiring: O(v^2 log V + ve).
299 // General: exponential.
303 // where v is the number of states visited and e is the number of arcs visited.
305 // For more information, see:
307 // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
308 // algorithms for weighted transducers. International Journal of Computer
309 // Science 13(1): 129-143.
311 void RmEpsilon(MutableFst<Arc> *fst, bool connect = true,
312 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
313 typename Arc::StateId state_threshold = kNoStateId,
314 float delta = kDelta) {
315 using StateId = typename Arc::StateId;
316 using Weight = typename Arc::Weight;
// Shortest distances for the epsilon-closure computation; the same vector is
// given to the queue and to the options-based RmEpsilon() call.
317 std::vector<Weight> distance;
// Queue over the epsilon subgraph (EpsilonArcFilter). AutoQueue presumably
// selects a concrete discipline from the FST's structure; see its docs.
318 AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
319 RmEpsilonOptions<Arc, AutoQueue<StateId>> opts(
320 &state_queue, delta, connect, weight_threshold, state_threshold);
// Delegate to the fully configurable overload.
321 RmEpsilon(fst, &distance, opts);
// Options for the delayed RmEpsilonFst: cache options plus `delta`, the
// shortest-distance convergence delta (presumably forwarded by
// RmEpsilonFstImpl through its `delta_` member; TODO confirm).
// NOTE(review): the declaration of the `delta` member is not visible in this
// excerpt.
324 struct RmEpsilonFstOptions : CacheOptions {
// Copies the given cache options.
327 explicit RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
328 : CacheOptions(opts), delta(delta) {}
// Uses default-constructed cache options.
330 explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
335 // Implementation of delayed RmEpsilonFst.
// A state is expanded through rmeps_state_ the first time its final weight or
// arcs are requested; the result is stored in the cache and later requests are
// answered from it. State IDs are those of the input FST.
337 class RmEpsilonFstImpl : public CacheImpl<Arc> {
339 using StateId = typename Arc::StateId;
340 using Weight = typename Arc::Weight;
342 using Store = DefaultCacheStore<Arc>;
343 using State = typename Store::State;
345 using FstImpl<Arc>::Properties;
346 using FstImpl<Arc>::SetType;
347 using FstImpl<Arc>::SetProperties;
348 using FstImpl<Arc>::SetInputSymbols;
349 using FstImpl<Arc>::SetOutputSymbols;
351 using CacheBaseImpl<CacheState<Arc>>::HasArcs;
352 using CacheBaseImpl<CacheState<Arc>>::HasFinal;
353 using CacheBaseImpl<CacheState<Arc>>::HasStart;
354 using CacheBaseImpl<CacheState<Arc>>::PushArc;
355 using CacheBaseImpl<CacheState<Arc>>::SetArcs;
356 using CacheBaseImpl<CacheState<Arc>>::SetFinal;
357 using CacheBaseImpl<CacheState<Arc>>::SetStart;
// Expansion uses a FIFO queue and delta_. The third RmEpsilonOptions argument
// (false) presumably maps to `connect`; RmEpsilonState::Expand does not read it.
359 RmEpsilonFstImpl(const Fst<Arc> &fst, const RmEpsilonFstOptions &opts)
360 : CacheImpl<Arc>(opts),
365 RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false)) {
366 SetType("rmepsilon");
368 RmEpsilonProperties(fst.Properties(kFstProperties, false), true),
370 SetInputSymbols(fst.InputSymbols());
371 SetOutputSymbols(fst.OutputSymbols());
// Copy: copies the cache and the input FST (via Copy(true)), then builds a
// fresh expansion state bound to this object's own queue_.
374 RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
375 : CacheImpl<Arc>(impl),
376 fst_(impl.fst_->Copy(true)),
380 RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false)) {
381 SetType("rmepsilon");
382 SetProperties(impl.Properties(), kCopyProperties);
383 SetInputSymbols(impl.InputSymbols());
384 SetOutputSymbols(impl.OutputSymbols());
// The start state is taken unchanged from the input FST.
388 if (!HasStart()) SetStart(fst_->Start());
389 return CacheImpl<Arc>::Start();
// The accessors below expand s on first use and then read the cache.
392 Weight Final(StateId s) {
393 if (!HasFinal(s)) Expand(s);
394 return CacheImpl<Arc>::Final(s);
397 size_t NumArcs(StateId s) {
398 if (!HasArcs(s)) Expand(s);
399 return CacheImpl<Arc>::NumArcs(s);
402 size_t NumInputEpsilons(StateId s) {
403 if (!HasArcs(s)) Expand(s);
404 return CacheImpl<Arc>::NumInputEpsilons(s);
407 size_t NumOutputEpsilons(StateId s) {
408 if (!HasArcs(s)) Expand(s);
409 return CacheImpl<Arc>::NumOutputEpsilons(s);
412 uint64 Properties() const override { return Properties(kFstProperties); }
414 // Sets error if found and returns other FST impl properties.
// Errors come either from the input FST or from the expansion state.
415 uint64 Properties(uint64 mask) const override {
416 if ((mask & kError) &&
417 (fst_->Properties(kError, false) || rmeps_state_.Error())) {
418 SetProperties(kError, kError);
420 return FstImpl<Arc>::Properties(mask);
// Expands s if needed, then iterates over its cached arcs.
423 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
424 if (!HasArcs(s)) Expand(s);
425 CacheImpl<Arc>::InitArcIterator(s, data);
// Computes s's epsilon-closure and stores its final weight and arcs in the
// cache. Arcs are taken from the back of the collected vector (the pop and
// the closing SetArcs call are not visible in this excerpt).
428 void Expand(StateId s) {
429 rmeps_state_.Expand(s);
430 SetFinal(s, rmeps_state_.Final());
431 auto &arcs = rmeps_state_.Arcs();
432 while (!arcs.empty()) {
433 PushArc(s, arcs.back());
// Input FST (owned).
440 std::unique_ptr<const Fst<Arc>> fst_;
// Epsilon-closure distances shared with rmeps_state_.
442 std::vector<Weight> distance_;
// Queue for the epsilon shortest-distance computation.
443 FifoQueue<StateId> queue_;
444 internal::RmEpsilonState<Arc, FifoQueue<StateId>> rmeps_state_;
447 } // namespace internal
449 // Removes epsilon-transitions (when both the input and output label are an
450 // epsilon) from a transducer. The result will be an equivalent FST that has no
451 // such epsilon transitions. This version is a delayed FST.
457 // Unweighted: O(v^2 + ve).
458 // General: exponential.
462 // where v is the number of states visited and e is the number of arcs visited.
463 // Constant time to visit an input state or arc is assumed and exclusive of
// caching.
466 // For more information, see:
468 // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
469 // algorithms for weighted transducers. International Journal of Computer
470 // Science 13(1): 129-143.
472 // This class attaches interface to implementation and handles
473 // reference counting, delegating most methods to ImplToFst.
475 class RmEpsilonFst : public ImplToFst<internal::RmEpsilonFstImpl<A>> {
478 using StateId = typename Arc::StateId;
480 using Store = DefaultCacheStore<Arc>;
481 using State = typename Store::State;
482 using Impl = internal::RmEpsilonFstImpl<Arc>;
// The iterator specializations call GetImpl()/GetMutableImpl().
484 friend class ArcIterator<RmEpsilonFst<Arc>>;
485 friend class StateIterator<RmEpsilonFst<Arc>>;
// Delayed epsilon removal over `fst` with default RmEpsilonFstOptions.
487 explicit RmEpsilonFst(const Fst<Arc> &fst)
488 : ImplToFst<Impl>(std::make_shared<Impl>(fst, RmEpsilonFstOptions())) {}
// Delayed epsilon removal over `fst` with explicit cache options and delta.
490 RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
491 : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
493 // See Fst<>::Copy() for doc.
494 RmEpsilonFst(const RmEpsilonFst<Arc> &fst, bool safe = false)
495 : ImplToFst<Impl>(fst, safe) {}
497 // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
498 RmEpsilonFst<Arc> *Copy(bool safe = false) const override {
499 return new RmEpsilonFst<Arc>(*this, safe);
502 inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
// Const interface, but arc iteration may expand and cache state s, so the
// mutable implementation is used.
504 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
505 GetMutableImpl()->InitArcIterator(s, data);
509 using ImplToFst<Impl>::GetImpl;
510 using ImplToFst<Impl>::GetMutableImpl;
// Copy assignment is disabled; use Copy() or the copy constructor.
512 RmEpsilonFst &operator=(const RmEpsilonFst &) = delete;
515 // Specialization for RmEpsilonFst.
// Iterates states through CacheStateIterator over the FST's mutable
// implementation.
517 class StateIterator<RmEpsilonFst<Arc>>
518 : public CacheStateIterator<RmEpsilonFst<Arc>> {
520 explicit StateIterator(const RmEpsilonFst<Arc> &fst)
521 : CacheStateIterator<RmEpsilonFst<Arc>>(fst, fst.GetMutableImpl()) {}
524 // Specialization for RmEpsilonFst.
// Iterates the cached arcs of state s, expanding s first if its arcs have not
// been computed yet.
526 class ArcIterator<RmEpsilonFst<Arc>>
527 : public CacheArcIterator<RmEpsilonFst<Arc>> {
529 using StateId = typename Arc::StateId;
531 ArcIterator(const RmEpsilonFst<Arc> &fst, StateId s)
532 : CacheArcIterator<RmEpsilonFst<Arc>>(fst.GetMutableImpl(), s) {
533 if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
538 inline void RmEpsilonFst<Arc>::InitStateIterator(
539 StateIteratorData<Arc> *data) const {
// Installs a newly allocated StateIterator specialization in `data`.
// NOTE(review): ownership of data->base presumably passes to the
// StateIteratorData holder; confirm against its definition.
540 data->base = new StateIterator<RmEpsilonFst<Arc>>(*this);
543 // Convenience alias: delayed epsilon removal over StdArc.
544 using StdRmEpsilonFst = RmEpsilonFst<StdArc>;
548 #endif // FST_LIB_RMEPSILON_H_