1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions and classes to find shortest distance in an FST.
6 #ifndef FST_SHORTEST_DISTANCE_H_
7 #define FST_SHORTEST_DISTANCE_H_
14 #include <fst/arcfilter.h>
15 #include <fst/cache.h>
16 #include <fst/queue.h>
17 #include <fst/reverse.h>
18 #include <fst/test-properties.h>
23 // Default convergence delta for the shortest distance and shortest path algorithms; used as the default `delta` argument of ShortestDistanceOptions and ShortestDistance().
24 constexpr float kShortestDelta = 1e-6;
26 template <class Arc, class Queue, class ArcFilter>
27 struct ShortestDistanceOptions {  // Options controlling ShortestDistance(fst, distance, opts).
28 using StateId = typename Arc::StateId;
30 Queue *state_queue; // Queue discipline used; owned by caller (only a pointer is stored).
31 ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph); arcs it rejects are not relaxed.
32 StateId source; // If kNoStateId, use the FST's initial state.
33 float delta; // Determines the degree of convergence required (defaults to kShortestDelta).
34 bool first_path; // For a semiring with the path property (otherwise
35 // undefined), compute the shortest distances along
36 // the first path to a final state found
37 // by the algorithm. That path is the shortest path
38 // only if the FST has a unique final state (or all
39 // the final states have the same final weight), the
40 // queue discipline is shortest-first, and all the
41 // weights in the FST are between One() and Zero()
42 // according to NaturalLess.
44 ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
45 StateId source = kNoStateId,
46 float delta = kShortestDelta)
47 : state_queue(state_queue),
48 arc_filter(arc_filter),
56 // Computation state of the shortest-distance algorithm. Reusable information
57 // is maintained across calls to member function ShortestDistance(source) when
58 // retain is true for improved efficiency when calling multiple times from
59 // different source states (e.g., in epsilon removal). Contrary to the usual
60 // conventions, fst must not be freed before this object. Vector distance
61 // should not be modified by the user between these calls. The Error() method
62 // returns true iff an error was encountered.
63 template <class Arc, class Queue, class ArcFilter>
64 class ShortestDistanceState {
66 using StateId = typename Arc::StateId;
67 using Weight = typename Arc::Weight;
69 ShortestDistanceState(
70 const Fst<Arc> &fst, std::vector<Weight> *distance,
71 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain)
74 state_queue_(opts.state_queue),  // The queue itself is owned by the caller of the options.
75 arc_filter_(opts.arc_filter),
77 first_path_(opts.first_path),
84 void ShortestDistance(StateId source);  // source == kNoStateId selects the FST's initial state.
86 bool Error() const { return error_; }
90 std::vector<Weight> *distance_;  // Caller-provided output vector, indexed by StateId.
92 ArcFilter arc_filter_;
94 const bool first_path_;  // If true, stop as soon as a final state is dequeued.
95 const bool retain_; // Retain and reuse information across calls.
97 std::vector<Adder<Weight>> adder_; // Sums distance_ accurately.
98 std::vector<Adder<Weight>> radder_; // Relaxation distance; read and reset when the state is dequeued.
99 std::vector<bool> enqueued_; // Is state enqueued?
100 std::vector<StateId> sources_; // Source ID for ith state in distance_,
101 // (r)adder_, and enqueued_ if retained.
102 StateId source_id_; // Unique ID characterizing each call.
106 // Compute the shortest distance; if source is kNoStateId, uses the initial
108 template <class Arc, class Queue, class ArcFilter>
109 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
111 if (fst_.Start() == kNoStateId) {  // FST has no start state.
112 if (fst_.Properties(kError, false)) error_ = true;
115 if (!(Weight::Properties() & kRightSemiring)) {  // Relaxation extends paths on the right: Times(r, arc.weight).
116 FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
121 if (first_path_ && !(Weight::Properties() & kPath)) {
122 FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
123 << "Weight does not have the path property: " << Weight::Type();
127 state_queue_->Clear();
134 if (source == kNoStateId) source = fst_.Start();
135 while (distance_->size() <= source) {  // Grow the per-state vectors to cover source.
136 distance_->push_back(Weight::Zero());
137 adder_.push_back(Adder<Weight>());
138 radder_.push_back(Adder<Weight>());
139 enqueued_.push_back(false);
142 while (sources_.size() <= source) sources_.push_back(kNoStateId);
143 sources_[source] = source_id_;
145 (*distance_)[source] = Weight::One();  // The source is reached with distance One().
146 adder_[source].Reset(Weight::One());
147 radder_[source].Reset(Weight::One());
148 enqueued_[source] = true;
149 state_queue_->Enqueue(source);
150 while (!state_queue_->Empty()) {
151 const auto state = state_queue_->Head();
152 state_queue_->Dequeue();
153 while (distance_->size() <= state) {
154 distance_->push_back(Weight::Zero());
155 adder_.push_back(Adder<Weight>());
156 radder_.push_back(Adder<Weight>());
157 enqueued_.push_back(false);
159 if (first_path_ && (fst_.Final(state) != Weight::Zero())) break;  // first_path: stop at the first final state dequeued.
160 enqueued_[state] = false;
161 const auto r = radder_[state].Sum();  // Pending relaxation weight for state; consumed (reset) on the next line.
162 radder_[state].Reset();
163 for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
165 const auto &arc = aiter.Value();
166 if (!arc_filter_(arc)) continue;  // Arcs rejected by the filter are not relaxed.
167 while (distance_->size() <= arc.nextstate) {
168 distance_->push_back(Weight::Zero());
169 adder_.push_back(Adder<Weight>());
170 radder_.push_back(Adder<Weight>());
171 enqueued_.push_back(false);
174 while (sources_.size() <= arc.nextstate) sources_.push_back(kNoStateId);
175 if (sources_[arc.nextstate] != source_id_) {  // Entry is stale (written by a different call): reset it.
176 (*distance_)[arc.nextstate] = Weight::Zero();
177 adder_[arc.nextstate].Reset();
178 radder_[arc.nextstate].Reset();
179 enqueued_[arc.nextstate] = false;
180 sources_[arc.nextstate] = source_id_;
183 auto &nd = (*distance_)[arc.nextstate];
184 auto &na = adder_[arc.nextstate];
185 auto &nr = radder_[arc.nextstate];
186 auto weight = Times(r, arc.weight);  // Candidate extension of the paths through state.
187 if (!ApproxEqual(nd, Plus(nd, weight), delta_)) {  // Act only if the candidate changes nd by more than delta_.
190 if (!nd.Member() || !nr.Sum().Member()) {  // Checks for non-member (invalid) weights.
194 if (!enqueued_[arc.nextstate]) {  // Enqueue if not queued; otherwise (presumably) notify the queue via Update().
195 state_queue_->Enqueue(arc.nextstate);
196 enqueued_[arc.nextstate] = true;
198 state_queue_->Update(arc.nextstate);
204 if (fst_.Properties(kError, false)) error_ = true;  // Propagate errors from the input FST.
207 } // namespace internal
209 // Shortest-distance algorithm: this version allows fine control
210 // via the options argument. See below for a simpler interface.
212 // This computes the shortest distance from the opts.source state to each
213 // visited state S and stores the value in the distance vector. An
214 // unvisited state S has distance Zero(), which will be stored in the
215 // distance vector if S is less than the maximum visited state. The state
216 // queue discipline, arc filter, and convergence delta are taken in the
217 // options argument. The distance vector will contain a unique element for
218 // which Member() is false if an error was encountered.
220 // The weights must be right distributive and k-closed (i.e., 1 +
221 // x + x^2 + ... + x^(k + 1) = 1 + x + x^2 + ... + x^k).
225 // Depends on properties of the semiring and the queue discipline.
227 // For more information, see:
229 // Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
230 // problems, Journal of Automata, Languages and
231 // Combinatorics 7(3): 321-350, 2002.
232 template <class Arc, class Queue, class ArcFilter>
233 void ShortestDistance(
234 const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance,
235 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
236 internal::ShortestDistanceState<Arc, Queue, ArcFilter> sd_state(fst, distance,
238 sd_state.ShortestDistance(opts.source);  // opts.source == kNoStateId selects the initial state.
239 if (sd_state.Error()) {
241 distance->resize(1, Arc::Weight::NoWeight());  // Error marker: a single NoWeight() element. NOTE(review): only holds if distance was cleared first — confirm.
245 // Shortest-distance algorithm: simplified interface. See above for a version
246 // that permits finer control.
248 // If reverse is false, this computes the shortest distance from the initial
249 // state to each state S and stores the value in the distance vector. If
250 // reverse is true, this computes the shortest distance from each state to the
251 // final states. An unvisited state S has distance Zero(), which will be stored
252 // in the distance vector if S is less than the maximum visited state. The
253 // state queue discipline is automatically selected. The distance vector will
254 // contain a unique element for which Member() is false if an error was
257 // The weights must be right (left) distributive if reverse is false (true)
258 // and k-closed (i.e., 1 + x + x^2 + ... + x^(k + 1) = 1 + x + x^2 + ... + x^k).
260 // Arc weights must satisfy the property that the sum of the weights of one or
261 // more paths from some state S to T is never Zero(). In particular, arc weights
266 // Depends on properties of the semiring and the queue discipline.
268 // For more information, see:
270 // Mohri, M. 2002. Semiring framework and algorithms for
271 // shortest-distance problems, Journal of Automata, Languages and
272 // Combinatorics 7(3): 321-350, 2002.
274 void ShortestDistance(const Fst<Arc> &fst,
275 std::vector<typename Arc::Weight> *distance,
276 bool reverse = false, float delta = kShortestDelta) {
277 using StateId = typename Arc::StateId;
278 using Weight = typename Arc::Weight;
280 AnyArcFilter<Arc> arc_filter;  // Forward computation: no arc filtering, automatically selected queue.
281 AutoQueue<StateId> state_queue(fst, distance, arc_filter);
282 const ShortestDistanceOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>>
283 opts(&state_queue, arc_filter, kNoStateId, delta);
284 ShortestDistance(fst, distance, opts);
286 using ReverseArc = ReverseArc<Arc>;  // Reverse computation: run forward on the reversed FST.
287 using ReverseWeight = typename ReverseArc::Weight;
288 AnyArcFilter<ReverseArc> rarc_filter;
289 VectorFst<ReverseArc> rfst;  // Reversed FST; presumably populated from fst before use — confirm.
291 std::vector<ReverseWeight> rdistance;
292 AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
293 const ShortestDistanceOptions<ReverseArc, AutoQueue<StateId>,
294 AnyArcFilter<ReverseArc>>
295 ropts(&state_queue, rarc_filter, kNoStateId, delta);
296 ShortestDistance(rfst, &rdistance, ropts);
298 if (rdistance.size() == 1 && !rdistance[0].Member()) {  // Propagate the error marker from the reverse computation.
299 distance->resize(1, Arc::Weight::NoWeight());
302 while (distance->size() < rdistance.size() - 1) {  // Reversed state i + 1 maps to original state i (offset of one, presumably a super-initial state 0). NOTE(review): size() - 1 underflows if rdistance is empty — confirm it cannot be.
303 distance->push_back(rdistance[distance->size() + 1].Reverse());  // Convert the reverse weight back to Weight.
308 // Returns the sum of the weights of all successful paths in an FST, i.e., the
309 // shortest-distance from the initial state to the final states. Returns a
310 // weight such that Member() is false if an error was encountered.
312 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst,
313 float delta = kShortestDelta) {
314 using StateId = typename Arc::StateId;
315 using Weight = typename Arc::Weight;
316 std::vector<Weight> distance;
317 if (Weight::Properties() & kRightSemiring) {  // Right semiring: combine forward distances with final weights.
318 ShortestDistance(fst, &distance, false, delta);  // Forward distances from the initial state.
319 if (distance.size() == 1 && !distance[0].Member()) {
320 return Arc::Weight::NoWeight();
322 Adder<Weight> adder; // maintains cumulative sum accurately
323 for (StateId state = 0; state < distance.size(); ++state) {  // NOTE(review): compares signed StateId with size_t.
324 adder.Add(Times(distance[state], fst.Final(state)));  // Sum over states of distance(s) (x) Final(s).
328 ShortestDistance(fst, &distance, true, delta);  // Presumably the non-right-semiring branch: distances from each state to the final states.
329 const auto state = fst.Start();
330 if (distance.size() == 1 && !distance[0].Member()) {
331 return Arc::Weight::NoWeight();
333 return state != kNoStateId && state < distance.size() ? distance[state]  // Guard against a missing start state or one not covered by distance.
340 #endif // FST_SHORTEST_DISTANCE_H_