1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions to find shortest paths in an FST.
6 #ifndef FST_LIB_SHORTEST_PATH_H_
7 #define FST_LIB_SHORTEST_PATH_H_
15 #include <fst/cache.h>
16 #include <fst/determinize.h>
17 #include <fst/queue.h>
18 #include <fst/shortest-distance.h>
19 #include <fst/test-properties.h>
// Options for the shortest-path functions in this file. Extends
// ShortestDistanceOptions (which receives the queue and the arc filter) with
// the number of paths requested, uniqueness, early-stopping and pruning
// settings.
24 template <class Arc, class Queue, class ArcFilter>
25 struct ShortestPathOptions
26 : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
27 using StateId = typename Arc::StateId;
28 using Weight = typename Arc::Weight;
30 int32 nshortest; // Returns n-shortest paths.
31 bool unique; // Only returns paths with distinct input strings.
32 bool has_distance; // Distance vector already contains the
33 // shortest distance from the initial state.
34 bool first_path; // Single shortest path stops after finding the first
35 // path to a final state; that path is the shortest path
36 // only when using the ShortestFirstQueue and
37 // only when all the weights in the FST are between
38 // One() and Zero() according to NaturalLess.
39 Weight weight_threshold; // Pruning weight threshold.
40 StateId state_threshold; // Pruning state threshold.
// Forwards queue and filter to the ShortestDistanceOptions base and stores
// has_distance, first_path and the two thresholds in the members of the same
// name. By default a single path is requested (nshortest = 1) and
// state_threshold is kNoStateId, the value for which NShortestPath skips its
// state-count check.
42 ShortestPathOptions(Queue *queue, ArcFilter filter, int32 nshortest = 1,
43 bool unique = false, bool has_distance = false,
44 float delta = kDelta, bool first_path = false,
45 Weight weight_threshold = Weight::Zero(),
46 StateId state_threshold = kNoStateId)
47 : ShortestDistanceOptions<Arc, Queue, ArcFilter>(queue, filter,
51 has_distance(has_distance),
52 first_path(first_path),
53 weight_threshold(std::move(weight_threshold)),
54 state_threshold(state_threshold) {}
59 constexpr size_t kNoArc = -1;  // Sentinel arc position (largest size_t value) stored in parent entries that record no backtrace arc.
61 // Helper function for SingleShortestPath building the shortest path as a left-
62 // to-right machine backwards from the best final state. It takes the input
63 // FST passed to SingleShortestPath and the parent vector and f_parent returned
64 // by that function, and builds the result into the provided output mutable FST.
65 // This is not normally called by users; see ShortestPath instead.
67 void SingleShortestPathBacktrace(
68 const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
69 const std::vector<std::pair<typename Arc::StateId, size_t>> &parent,
70 typename Arc::StateId f_parent) {
71 using StateId = typename Arc::StateId;
// The output shares the input's symbol tables.
73 ofst->SetInputSymbols(ifst.InputSymbols());
74 ofst->SetOutputSymbols(ifst.OutputSymbols());
75 StateId s_p = kNoStateId;
76 StateId d_p = kNoStateId;
// Follow parent[...].first from f_parent until kNoStateId is reached; d is the
// input state visited on the previous iteration (kNoStateId on the first one).
77 for (StateId state = f_parent, d = kNoStateId; state != kNoStateId;
78 d = state, state = parent[state].first) {
// Every visited input state gets a fresh output state.
80 s_p = ofst->AddState();
81 if (d == kNoStateId) {
// First iteration: state == f_parent, so the new state carries its final
// weight.
82 ofst->SetFinal(s_p, ifst.Final(f_parent));
// Look up the arc leaving `state` at the position recorded in parent[d] and
// add a copy of it to the new output state.
84 ArcIterator<Fst<Arc>> aiter(ifst, state);
85 aiter.Seek(parent[d].second);
86 auto arc = aiter.Value();
88 ofst->AddArc(s_p, arc);
// An error flagged on the input is propagated to the output.
92 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
94 ShortestPathProperties(ofst->Properties(kFstProperties, false), true),
98 // Helper function for SingleShortestPath building a tree of shortest paths to
99 // every final state in the input FST. It takes the input FST and parent values
100 // computed by SingleShortestPath and builds into the output mutable FST the
101 // subtree of ifst that consists only of the best paths to all final states.
102 // This is not normally called by users; see ShortestPath instead.
104 void SingleShortestTree(
105 const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
106 const std::vector<std::pair<typename Arc::StateId, size_t>> &parent) {
// Start from an empty output that shares the input's symbol tables and start
// state id.
107 ofst->DeleteStates();
108 ofst->SetInputSymbols(ifst.InputSymbols());
109 ofst->SetOutputSymbols(ifst.OutputSymbols());
110 ofst->SetStart(ifst.Start());
// Copy the final weight of every input state, using the same state ids.
111 for (StateIterator<Fst<Arc>> siter(ifst); !siter.Done(); siter.Next()) {
113 ofst->SetFinal(siter.Value(), ifst.Final(siter.Value()));
// For each parent entry that records both a state and an arc position, copy
// that arc of ifst (the one leaving pair.first at position pair.second).
115 for (const auto &pair : parent) {
116 if (pair.first != kNoStateId && pair.second != kNoArc) {
117 ArcIterator<Fst<Arc>> aiter(ifst, pair.first);
118 aiter.Seek(pair.second);
119 ofst->AddArc(pair.first, aiter.Value());
// An error flagged on the input is propagated to the output.
122 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
124 ShortestPathProperties(ofst->Properties(kFstProperties, false), true),
128 // Shortest-path algorithm. It computes the backtrace information describing
129 // the shortest path in the input FST; distance returns the shortest distances
130 // from the source state to each state in the input FST, and the options struct is
132 // used to specify options such as the queue discipline, the arc filter and
133 // delta. The f_parent argument is an output parameter indicating the final
134 // state, and the parent argument is used for the storage of the backtrace path
135 // for each state 1 to n, (i.e., the best previous state and the arc that
136 // transitions to state n.) The shortest path is the lowest weight path w.r.t.
137 // the natural semiring order. The weights need to be right distributive and
138 // have the path (kPath) property. False is returned if an error is encountered.
140 // This is not normally called by users; see ShortestPath instead (with n = 1).
141 template <class Arc, class Queue, class ArcFilter>
142 bool SingleShortestPath(
143 const Fst<Arc> &ifst, std::vector<typename Arc::Weight> *distance,
144 const ShortestPathOptions<Arc, Queue, ArcFilter> &opts,
145 typename Arc::StateId *f_parent,
146 std::vector<std::pair<typename Arc::StateId, size_t>> *parent) {
147 using StateId = typename Arc::StateId;
148 using Weight = typename Arc::Weight;
// With no start state there is nothing to search: f_parent stays kNoStateId
// and the call succeeds.
150 *f_parent = kNoStateId;
151 if (ifst.Start() == kNoStateId) return true;
// enqueued[s] tracks whether state s is currently known to the queue.
152 std::vector<bool> enqueued;
153 auto state_queue = opts.state_queue;
// The search starts at opts.source when given, otherwise at the start state.
154 const auto source = (opts.source == kNoStateId) ? ifst.Start() : opts.source;
155 bool final_seen = false;
156 auto f_distance = Weight::Zero();
158 state_queue->Clear();
// The weight type must have both the path property and right distributivity.
159 if ((Weight::Properties() & (kPath | kRightSemiring)) !=
160 (kPath | kRightSemiring)) {
161 FSTERROR() << "SingleShortestPath: Weight needs to have the path"
162 << " property and be right distributive: " << Weight::Type();
// Pad the three per-state vectors with default entries for every state id
// below source.
165 while (distance->size() < source) {
166 distance->push_back(Weight::Zero());
167 enqueued.push_back(false);
168 parent->push_back(std::make_pair(kNoStateId, kNoArc));
// The source itself has distance One(), no parent, and seeds the queue.
170 distance->push_back(Weight::One());
171 parent->push_back(std::make_pair(kNoStateId, kNoArc));
172 state_queue->Enqueue(source);
173 enqueued.push_back(true);
174 while (!state_queue->Empty()) {
175 const auto s = state_queue->Head();
176 state_queue->Dequeue();
178 const auto sd = (*distance)[s];
179 // If we are using a shortest queue, no other path is going to be shorter
180 // than f_distance at this point.
181 if (opts.first_path && final_seen && f_distance == Plus(f_distance, sd)) {
// s is final: combine the weight of the complete path through s with the
// best final distance seen so far.
184 if (ifst.Final(s) != Weight::Zero()) {
185 const auto plus = Plus(f_distance, Times(sd, ifst.Final(s)));
186 if (f_distance != plus) {
190 if (!f_distance.Member()) return false;
// Relax every arc leaving s.
193 for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
194 const auto &arc = aiter.Value();
// Grow the per-state vectors so that arc.nextstate has an entry.
195 while (distance->size() <= arc.nextstate) {
196 distance->push_back(Weight::Zero());
197 enqueued.push_back(false);
198 parent->push_back(std::make_pair(kNoStateId, kNoArc));
200 auto &nd = (*distance)[arc.nextstate];
201 const auto weight = Times(sd, arc.weight);
// The path through s changes the destination's distance: store the new
// distance, remember (s, arc position) for the backtrace, and make sure the
// queue sees the destination (enqueue it, or update its queue position).
202 if (nd != Plus(nd, weight)) {
203 nd = Plus(nd, weight);
204 if (!nd.Member()) return false;
205 (*parent)[arc.nextstate] = std::make_pair(s, aiter.Position());
206 if (!enqueued[arc.nextstate]) {
207 state_queue->Enqueue(arc.nextstate);
208 enqueued[arc.nextstate] = true;
210 state_queue->Update(arc.nextstate);
// Comparison functor over state ids, used by NShortestPath as the ordering for
// its heap. Each id indexes `pairs`, which gives the (state, weight) pair for
// that id; the value compared is Times(PWeight(pair.first), pair.second).
218 template <class StateId, class Weight>
219 class ShortestPathCompare {
// pairs and distance are kept in the reference members pairs_ and distance_,
// so the comparator reads the vectors as they are at the time of each call.
221 ShortestPathCompare(const std::vector<std::pair<StateId, Weight>> &pairs,
222 const std::vector<Weight> &distance, StateId superfinal,
226 superfinal_(superfinal),
// Returns true when x should be ordered before y. In the ordinary case this is
// less_(wy, wx), i.e. the arguments are swapped, so std::push_heap/pop_heap
// (which surface the greatest element) surface the naturally-least weight.
229 bool operator()(const StateId x, const StateId y) const {
230 const auto &px = pairs_[x];
231 const auto &py = pairs_[y];
232 const auto wx = Times(PWeight(px.first), px.second);
233 const auto wy = Times(PWeight(py.first), py.second);
234 // Penalize complete paths to ensure correct results with inexact weights.
235 // This forms a strict weak order so long as ApproxEqual(a, b) =>
236 // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
237 if (px.first == superfinal_ && py.first != superfinal_) {
238 return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
239 } else if (py.first == superfinal_ && px.first != superfinal_) {
240 return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
242 return less_(wy, wx);
// Distance used for `state`: the superfinal state is handled separately,
// other states use distance_[state] when in range and Weight::Zero() otherwise.
247 Weight PWeight(StateId state) const {
248 return (state == superfinal_)
250 : (state < distance_.size()) ? distance_[state] : Weight::Zero();
253 const std::vector<std::pair<StateId, Weight>> &pairs_;
254 const std::vector<Weight> &distance_;
255 const StateId superfinal_;
257 NaturalLess<Weight> less_;
260 // N-Shortest-path algorithm: implements the core n-shortest path algorithm.
261 // The output is built reversed. See below for versions with more options.
264 // The output mutable FST contains the REVERSE of n-shortest paths in the input
265 // FST; distance must contain the shortest distance from each state to a final
266 // state in the input FST; delta is the convergence delta.
268 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
269 // semiring order. The single path that can be read from the ith of at most n
270 // transitions leaving the initial state of the output FST is the ith
271 // shortest path. Disregarding the initial state and initial transitions, the
272 // n-shortest paths, in fact, form a tree rooted at the single final state.
274 // The weights need to be left and right distributive (kSemiring) and have the
275 // path (kPath) property.
277 // Arc weights must satisfy the property that the sum of the weights of one or
278 // more paths from some state S to T is never Zero(). In particular, arc weights are never Zero().
281 // For more information, see:
283 // Mohri, M, and Riley, M. 2002. An efficient algorithm for the n-best-strings
284 // problem. In Proc. ICSLP.
286 // The algorithm relies on the shortest-distance algorithm. There are some
287 // issues with the pseudo-code as written in the paper (viz., line 11).
289 // IMPLEMENTATION NOTE: The input FST can be a delayed FST and at any state
290 // in its expansion the values of distance vector need only be defined at that
291 // time for the states that are known to exist.
292 template <class Arc, class RevArc>
293 void NShortestPath(const Fst<RevArc> &ifst, MutableFst<Arc> *ofst,
294 const std::vector<typename Arc::Weight> &distance,
295 int32 nshortest, float delta = kDelta,
296 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
297 typename Arc::StateId state_threshold = kNoStateId) {
298 using StateId = typename Arc::StateId;
299 using Weight = typename Arc::Weight;
300 using Pair = std::pair<StateId, Weight>;
// A non-positive path count leaves ofst untouched.
301 if (nshortest <= 0) return;
302 // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
303 // way to "deregister" this operation for non-path semirings so an informative
304 // error message is produced.
305 if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
306 FSTERROR() << "NShortestPath: Weight needs to have the "
307 << "path property and be distributive: " << Weight::Type();
308 ofst->SetProperties(kError, kError);
// Start from an empty output that shares the input's symbol tables.
311 ofst->DeleteStates();
312 ofst->SetInputSymbols(ifst.InputSymbols());
313 ofst->SetOutputSymbols(ifst.OutputSymbols());
314 // Each state in ofst corresponds to a path with weight w from the initial
315 // state of ifst to a state s in ifst, that can be characterized by a pair
316 // (s, w). The vector pairs maps each state in ofst to the corresponding
317 // pair (s, w).
318 std::vector<Pair> pairs;
319 // The superfinal state is denoted by kNoStateId. The distance from the
320 // superfinal state to the final state is semiring One, so
321 // `distance[kNoStateId]` is not needed.
322 const ShortestPathCompare<StateId, Weight> compare(pairs, distance,
324 const NaturalLess<Weight> less;
// Degenerate cases: ifst has no start state, the start state has no distance
// entry or its distance is Zero(), weight_threshold is less than One(), or
// state_threshold is 0.
325 if (ifst.Start() == kNoStateId || distance.size() <= ifst.Start() ||
326 distance[ifst.Start()] == Weight::Zero() ||
327 less(weight_threshold, Weight::One()) || state_threshold == 0) {
328 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
// Create the output start state and the single final state; the final state
// stands for the empty path at ifst's start state, i.e. the pair
// (ifst.Start(), One()).
331 ofst->SetStart(ofst->AddState());
332 const auto final_state = ofst->AddState();
333 ofst->SetFinal(final_state, Weight::One());
334 while (pairs.size() <= final_state) {
335 pairs.push_back(std::make_pair(kNoStateId, Weight::Zero()));
337 pairs[final_state] = std::make_pair(ifst.Start(), Weight::One());
// heap holds ofst state ids and is ordered by `compare`.
338 std::vector<StateId> heap;
339 heap.push_back(final_state);
// Pruning bound derived from the start state's distance and weight_threshold.
340 const auto limit = Times(distance[ifst.Start()], weight_threshold);
341 // r[s + 1], s state in fst, is the number of states in ofst whose
342 // corresponding pair contains s, i.e., it is the number of paths computed so
343 // far to s. Valid for s == kNoStateId (the superfinal state).
345 while (!heap.empty()) {
// Take the best remaining candidate according to `compare`.
346 std::pop_heap(heap.begin(), heap.end(), compare);
347 const auto state = heap.back();
348 const auto p = pairs[state];
351 (p.first == kNoStateId)
353 : (p.first < distance.size()) ? distance[p.first] : Weight::Zero();
// Pruning test: the candidate's estimated total weight exceeds limit, or the
// output already has state_threshold states (when a threshold is set).
354 if (less(limit, Times(d, p.second)) ||
355 (state_threshold != kNoStateId &&
356 ofst->NumStates() >= state_threshold)) {
// Make sure r has a counter for p.first.
359 while (r.size() <= p.first + 1) r.push_back(0);
// The candidate is a complete path (it sits at the superfinal state): hook it
// to the output start state with a 0:0 arc of weight One().
361 if (p.first == kNoStateId) {
362 ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
// Stop as soon as the superfinal counter reaches nshortest; skip candidates
// whose state's counter already exceeds nshortest; the superfinal state has
// nothing to expand.
364 if ((p.first == kNoStateId) && (r[p.first + 1] == nshortest)) break;
365 if (r[p.first + 1] > nshortest) continue;
366 if (p.first == kNoStateId) continue;
// Expand the candidate along every arc leaving p.first in ifst: each arc gives
// a new ofst state for (arc destination, accumulated weight), linked back to
// `state` by the arc with its weight reversed, and pushed on the heap.
367 for (ArcIterator<Fst<RevArc>> aiter(ifst, p.first); !aiter.Done();
369 const auto &rarc = aiter.Value();
370 Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
371 const auto weight = Times(p.second, arc.weight);
372 const auto next = ofst->AddState();
373 pairs.push_back(std::make_pair(arc.nextstate, weight));
374 arc.nextstate = state;
375 ofst->AddArc(next, arc);
376 heap.push_back(next);
377 std::push_heap(heap.begin(), heap.end(), compare);
// If p.first is final in ifst, also add a candidate at the superfinal state
// (kNoStateId), linked back to `state` by a 0:0 arc carrying the final weight.
379 const auto final_weight = ifst.Final(p.first).Reverse();
380 if (final_weight != Weight::Zero()) {
381 const auto weight = Times(p.second, final_weight);
382 const auto next = ofst->AddState();
383 pairs.push_back(std::make_pair(kNoStateId, weight));
384 ofst->AddArc(next, Arc(0, 0, final_weight, state));
385 heap.push_back(next);
386 std::push_heap(heap.begin(), heap.end(), compare);
// An error flagged on the input is propagated to the output.
390 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
392 ShortestPathProperties(ofst->Properties(kFstProperties, false)),
396 } // namespace internal
398 // N-Shortest-path algorithm: this version allows finer control via the options
399 // argument. See below for a simpler interface. The output mutable FST contains
400 // the n-shortest paths in the input FST; the distance argument is used to
401 // return the shortest distances from the source state to each state in the
402 // input FST, and the options struct is used to specify the number of paths to
403 // return, whether they need to have distinct input strings, the queue
404 // discipline, the arc filter and the convergence delta.
406 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
407 // semiring order. The single path that can be read from the ith of at most n
408 // transitions leaving the initial state of the output FST is the ith shortest path.
410 // Disregarding the initial state and initial transitions, the n-shortest paths,
411 // in fact, form a tree rooted at the single final state.
413 // The weights need to be right distributive and have the path (kPath) property.
414 // They need to be left distributive as well for nshortest > 1.
416 // For more information, see:
418 // Mohri, M, and Riley, M. 2002. An efficient algorithm for the n-best-strings
419 // problem. In Proc. ICSLP.
421 // The algorithm relies on the shortest-distance algorithm. There are some
422 // issues with the pseudo-code as written in the paper (viz., line 11).
423 template <class Arc, class Queue, class ArcFilter>
424 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
425 std::vector<typename Arc::Weight> *distance,
426 const ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
427 using StateId = typename Arc::StateId;
428 using Weight = typename Arc::Weight;
429 using RevArc = ReverseArc<Arc>;
// Single-path case: run SingleShortestPath and build the result with
// SingleShortestPathBacktrace; the error property is set on ofst on failure.
430 if (opts.nshortest == 1) {
431 std::vector<std::pair<StateId, size_t>> parent;
433 if (internal::SingleShortestPath(ifst, distance, opts, &f_parent,
435 internal::SingleShortestPathBacktrace(ifst, ofst, parent, f_parent);
437 ofst->SetProperties(kError, kError);
// A non-positive path count produces nothing.
441 if (opts.nshortest <= 0) return;
442 // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
443 // way to "deregister" this operation for non-path semirings so an informative
444 // error message is produced.
445 if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
446 FSTERROR() << "ShortestPath: Weight needs to have the "
447 << "path property and be distributive: " << Weight::Type();
448 ofst->SetProperties(kError, kError);
// Compute the distances unless the caller supplied them. A result consisting
// of a single non-Member weight is treated as an error.
451 if (!opts.has_distance) {
452 ShortestDistance(ifst, distance, opts);
453 if (distance->size() == 1 && !(*distance)[0].Member()) {
454 ofst->SetProperties(kError, kError);
458 // Algorithm works on the reverse of 'fst'; 'distance' is the distance to the
459 // final state in 'rfst', 'ofst' is built as the reverse of the tree of
460 // n-shortest path in 'rfst'.
461 VectorFst<RevArc> rfst;
462 Reverse(ifst, &rfst);
// d becomes the distance entry for state 0 of rfst: the Plus over its outgoing
// arcs of the reversed arc weight times the distance of the arc's destination.
// State ids of rfst are offset by one from the indices of `distance`, hence
// the "- 1" here and the insertion at the front below.
463 auto d = Weight::Zero();
464 for (ArcIterator<VectorFst<RevArc>> aiter(rfst, 0); !aiter.Done();
466 const auto &arc = aiter.Value();
467 const auto state = arc.nextstate - 1;
468 if (state < distance->size()) {
469 d = Plus(d, Times(arc.weight.Reverse(), (*distance)[state]));
472 // TODO(kbg): Avoid this expensive vector operation.
473 distance->insert(distance->begin(), d);
// Variant that runs directly on rfst.
475 internal::NShortestPath(rfst, ofst, *distance, opts.nshortest, opts.delta,
476 opts.weight_threshold, opts.state_threshold);
// Variant that runs on a determinized view of rfst, with the distances
// produced for it in ddistance.
// NOTE(review): presumably the variant used when opts.unique is set - confirm.
478 std::vector<Weight> ddistance;
479 DeterminizeFstOptions<RevArc> dopts(opts.delta);
480 DeterminizeFst<RevArc> dfst(rfst, distance, &ddistance, dopts);
481 internal::NShortestPath(dfst, ofst, ddistance, opts.nshortest, opts.delta,
482 opts.weight_threshold, opts.state_threshold);
// Remove the entry inserted at the front above, restoring the caller's
// indexing of distance.
484 // TODO(kbg): Avoid this expensive vector operation.
485 distance->erase(distance->begin());
488 // Shortest-path algorithm: simplified interface. See above for a version that
489 // allows finer control. The output mutable FST contains the n-shortest paths
490 // in the input FST. The queue discipline is automatically selected. When unique
491 // is true, only paths with distinct input label sequences are returned.
493 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
494 // semiring order. The single path that can be read from the ith of at most n
495 // transitions leaving the initial state of the output FST is the ith best path.
496 // The weights need to be right distributive and have the path (kPath) property.
498 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
499 int32 nshortest = 1, bool unique = false,
500 bool first_path = false,
501 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
502 typename Arc::StateId state_threshold = kNoStateId) {
503 using StateId = typename Arc::StateId;
// The distance vector is local to this call; it is not returned to the caller.
504 std::vector<typename Arc::Weight> distance;
// Every arc is considered (AnyArcFilter) and the queue is an AutoQueue built
// from the input FST.
505 AnyArcFilter<Arc> arc_filter;
506 AutoQueue<StateId> state_queue(ifst, &distance, arc_filter);
// Positional arguments after the filter: nshortest, unique,
// has_distance = false, delta = kDelta, first_path, then the two thresholds.
507 const ShortestPathOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>> opts(
508 &state_queue, arc_filter, nshortest, unique, false, kDelta, first_path,
509 weight_threshold, state_threshold);
// Delegate to the options-based overload.
510 ShortestPath(ifst, ofst, &distance, opts);
515 #endif // FST_LIB_SHORTEST_PATH_H_