1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions to find shortest paths in an FST.
6 #ifndef FST_SHORTEST_PATH_H_
7 #define FST_SHORTEST_PATH_H_
10 #include <type_traits>
16 #include <fst/cache.h>
17 #include <fst/determinize.h>
18 #include <fst/queue.h>
19 #include <fst/shortest-distance.h>
20 #include <fst/test-properties.h>
// Options for the shortest-path functions in this file. Adds path-count,
// uniqueness, early-stopping and pruning settings on top of the
// ShortestDistanceOptions base it derives from.
25 template <class Arc, class Queue, class ArcFilter>
26 struct ShortestPathOptions
27 : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
28 using StateId = typename Arc::StateId;
29 using Weight = typename Arc::Weight;
31 int32 nshortest; // Returns n-shortest paths.
32 bool unique; // Only returns paths with distinct input strings.
33 bool has_distance; // Distance vector already contains the
34 // shortest distance from the initial state.
35 bool first_path; // Single shortest path stops after finding the first
36 // path to a final state; that path is the shortest path
37 // only when using the ShortestFirstQueue and
38 // only when all the weights in the FST are between
39 // One() and Zero() according to NaturalLess.
40 Weight weight_threshold; // Pruning weight threshold.
41 StateId state_threshold; // Pruning state threshold.
// Constructs the options. `queue` and `filter` are forwarded to the
// ShortestDistanceOptions base; every other parameter has a default (one path,
// not unique, no precomputed distance, kDelta, no early stop, and
// Weight::Zero() / kNoStateId as the pruning thresholds). `weight_threshold`
// is taken by value and moved into the member.
// NOTE(review): as seen here the base-class initializer is not closed and no
// initializers for `nshortest` or `unique` are visible; confirm against the
// complete file.
43 ShortestPathOptions(Queue *queue, ArcFilter filter, int32 nshortest = 1,
44 bool unique = false, bool has_distance = false,
45 float delta = kDelta, bool first_path = false,
46 Weight weight_threshold = Weight::Zero(),
47 StateId state_threshold = kNoStateId)
48 : ShortestDistanceOptions<Arc, Queue, ArcFilter>(queue, filter,
52 has_distance(has_distance),
53 first_path(first_path),
54 weight_threshold(std::move(weight_threshold)),
55 state_threshold(state_threshold) {}
// Sentinel arc position stored in `parent` entries that have no recorded
// incoming arc. Initialized from -1, i.e. the largest size_t value.
60 constexpr size_t kNoArc = -1;
62 // Helper function for SingleShortestPath building the shortest path as a left-
63 // to-right machine backwards from the best final state. It takes the input
64 // FST passed to SingleShortestPath and the parent vector and f_parent returned
65 // by that function, and builds the result into the provided output mutable FST.
66 // This is not normally called by users; see ShortestPath instead.
// Parameters:
//   ifst:     input FST the parent information refers to.
//   ofst:     output FST that receives one new state per backtrace step.
//   parent:   per input state, the pair (predecessor state, position of the
//             arc in the predecessor that leads to this state).
//   f_parent: input state at which the backtrace starts; its final weight
//             becomes the final weight of the first state added to ofst.
68 void SingleShortestPathBacktrace(
69 const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
70 const std::vector<std::pair<typename Arc::StateId, size_t>> &parent,
71 typename Arc::StateId f_parent) {
72 using StateId = typename Arc::StateId;
// The output shares the input's symbol tables.
74 ofst->SetInputSymbols(ifst.InputSymbols());
75 ofst->SetOutputSymbols(ifst.OutputSymbols());
76 StateId s_p = kNoStateId;
77 StateId d_p = kNoStateId;
// Walks the predecessor chain: `state` is the input state being visited and
// `d` is the input state visited on the previous iteration (kNoStateId on the
// first one). The walk ends when a state has no predecessor.
78 for (StateId state = f_parent, d = kNoStateId; state != kNoStateId;
79 d = state, state = parent[state].first) {
81 s_p = ofst->AddState();
// First iteration: the new output state carries the final weight of f_parent.
82 if (d == kNoStateId) {
83 ofst->SetFinal(s_p, ifst.Final(f_parent));
// Fetches the arc of `state` at the position recorded for `d`, i.e. the arc
// that was used to reach `d` from `state`.
85 ArcIterator<Fst<Arc>> aiter(ifst, state);
86 aiter.Seek(parent[d].second);
87 auto arc = aiter.Value();
89 ofst->AddArc(s_p, arc);
// Propagates an input error to the output.
93 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
// Derives shortest-path property bits from the output's current properties.
// NOTE(review): the call that consumes this argument is not visible here.
95 ShortestPathProperties(ofst->Properties(kFstProperties, false), true),
99 // Helper function for SingleShortestPath building a tree of shortest paths to
100 // every final state in the input FST. It takes the input FST and parent values
101 // computed by SingleShortestPath and builds into the output mutable FST the
102 // subtree of ifst that consists only of the best paths to all final states.
103 // This is not normally called by users; see ShortestPath instead.
// Parameters:
//   ifst:   input FST the parent information refers to.
//   ofst:   output FST; its existing states are deleted before building.
//   parent: sequence of (state, arc position) pairs; each entry with a valid
//           state and arc position selects one arc of ifst to copy.
105 void SingleShortestTree(
106 const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
107 const std::vector<std::pair<typename Arc::StateId, size_t>> &parent) {
// Starts from an empty output that shares the input's symbol tables and start
// state id.
108 ofst->DeleteStates();
109 ofst->SetInputSymbols(ifst.InputSymbols());
110 ofst->SetOutputSymbols(ifst.OutputSymbols());
111 ofst->SetStart(ifst.Start());
// Copies final weights state by state, using the input's state ids.
112 for (StateIterator<Fst<Arc>> siter(ifst); !siter.Done(); siter.Next()) {
114 ofst->SetFinal(siter.Value(), ifst.Final(siter.Value()));
// Copies, for every entry that records both a state and an arc position, the
// arc of ifst found at that position of that state.
116 for (const auto &pair : parent) {
117 if (pair.first != kNoStateId && pair.second != kNoArc) {
118 ArcIterator<Fst<Arc>> aiter(ifst, pair.first);
119 aiter.Seek(pair.second);
120 ofst->AddArc(pair.first, aiter.Value());
// Propagates an input error to the output.
123 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
// Derives shortest-path property bits from the output's current properties.
// NOTE(review): the call that consumes this argument is not visible here.
125 ShortestPathProperties(ofst->Properties(kFstProperties, false), true),
129 // Shortest-path algorithm. It builds the output mutable FST so that it contains
130 // the shortest path in the input FST; distance returns the shortest distances
131 // from the source state to each state in the input FST, and the options struct
133 // is used to specify options such as the queue discipline, the arc filter and
134 // delta. The f_parent argument is an output parameter indicating the final
135 // state, and the parent argument is used for the storage of the backtrace path
136 // for each state 1 to n (i.e., the best previous state and the arc that
137 // transitions to state n). The shortest path is the lowest weight path w.r.t.
138 // the natural semiring order. The weights need to be right distributive and
139 // have the path (kPath) property. False is returned if an error is encountered.
141 // This is not normally called by users; see ShortestPath instead (with n = 1).
// Parameters:
//   ifst:     input FST.
//   distance: in/out vector of tentative distances, indexed by state id; grown
//             on demand with Weight::Zero().
//   opts:     supplies the state queue, the optional source state and the
//             first_path flag read below.
//   f_parent: output; reset to kNoStateId on entry.
//   parent:   output; per state, (predecessor state, arc position), grown in
//             step with `distance`.
// Returns true immediately when ifst has no start state, and false as soon as a
// computed weight is not a Member().
142 template <class Arc, class Queue, class ArcFilter>
143 bool SingleShortestPath(
144 const Fst<Arc> &ifst, std::vector<typename Arc::Weight> *distance,
145 const ShortestPathOptions<Arc, Queue, ArcFilter> &opts,
146 typename Arc::StateId *f_parent,
147 std::vector<std::pair<typename Arc::StateId, size_t>> *parent) {
148 using StateId = typename Arc::StateId;
149 using Weight = typename Arc::Weight;
// Compile-time requirements on the weight type.
150 static_assert((Weight::Properties() & kPath) == kPath,
151 "Weight must have path property.");
152 static_assert((Weight::Properties() & kRightSemiring) == kRightSemiring,
153 "Weight must be right distributive.");
155 *f_parent = kNoStateId;
156 if (ifst.Start() == kNoStateId) return true;
// enqueued[s] tracks whether state s has already been put on the queue.
157 std::vector<bool> enqueued;
158 auto state_queue = opts.state_queue;
// The search starts at opts.source when given, otherwise at the start state.
159 const auto source = (opts.source == kNoStateId) ? ifst.Start() : opts.source;
160 bool final_seen = false;
161 auto f_distance = Weight::Zero();
163 state_queue->Clear();
// Pads the three parallel vectors up to (but excluding) the source index with
// "unreached" entries.
164 while (distance->size() < source) {
165 distance->push_back(Weight::Zero());
166 enqueued.push_back(false);
167 parent->push_back(std::make_pair(kNoStateId, kNoArc));
// The source itself gets distance One(), no predecessor, and is enqueued.
169 distance->push_back(Weight::One());
170 parent->push_back(std::make_pair(kNoStateId, kNoArc));
171 state_queue->Enqueue(source);
172 enqueued.push_back(true);
173 while (!state_queue->Empty()) {
174 const auto s = state_queue->Head();
175 state_queue->Dequeue();
// Copy of the current distance to s; `distance` may be resized below.
177 const auto sd = (*distance)[s];
178 // If we are using a shortest queue, no other path is going to be shorter
179 // than f_distance at this point.
180 if (opts.first_path && final_seen && f_distance == Plus(f_distance, sd)) {
// For a final state, tests whether the complete path through s changes the
// best final distance seen so far.
183 if (ifst.Final(s) != Weight::Zero()) {
184 const auto plus = Plus(f_distance, Times(sd, ifst.Final(s)));
185 if (f_distance != plus) {
189 if (!f_distance.Member()) return false;
// Relaxes every arc leaving s.
192 for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
193 const auto &arc = aiter.Value();
// Grows the parallel vectors so that arc.nextstate is a valid index.
194 while (distance->size() <= arc.nextstate) {
195 distance->push_back(Weight::Zero());
196 enqueued.push_back(false);
197 parent->push_back(std::make_pair(kNoStateId, kNoArc));
199 auto &nd = (*distance)[arc.nextstate];
200 const auto weight = Times(sd, arc.weight);
// The path through s is recorded only when it changes the destination's
// distance under Plus.
201 if (nd != Plus(nd, weight)) {
202 nd = Plus(nd, weight);
203 if (!nd.Member()) return false;
// Remembers how arc.nextstate was reached: from s via the arc at this position.
204 (*parent)[arc.nextstate] = std::make_pair(s, aiter.Position());
// First visit enqueues the destination; the Update call below covers a state
// that is already known to the queue.
205 if (!enqueued[arc.nextstate]) {
206 state_queue->Enqueue(arc.nextstate);
207 enqueued[arc.nextstate] = true;
209 state_queue->Update(arc.nextstate);
// Comparison functor over indices into `pairs`. Each pair holds a state and a
// weight; an index is ranked by Times(PWeight(state), weight). It is used as
// the heap comparator by NShortestPath.
217 template <class StateId, class Weight>
218 class ShortestPathCompare {
// Keeps references to `pairs` and `distance` (both must outlive this object)
// and the id that denotes the superfinal state.
220 ShortestPathCompare(const std::vector<std::pair<StateId, Weight>> &pairs,
221 const std::vector<Weight> &distance, StateId superfinal,
225 superfinal_(superfinal),
// Returns true when the entry for y has lower weight than the entry for x under
// NaturalLess, with ApproxEqual-based tie handling when exactly one of the two
// refers to the superfinal state.
228 bool operator()(const StateId x, const StateId y) const {
229 const auto &px = pairs_[x];
230 const auto &py = pairs_[y];
231 const auto wx = Times(PWeight(px.first), px.second);
232 const auto wy = Times(PWeight(py.first), py.second);
233 // Penalize complete paths to ensure correct results with inexact weights.
234 // This forms a strict weak order so long as ApproxEqual(a, b) =>
235 // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
236 if (px.first == superfinal_ && py.first != superfinal_) {
237 return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
238 } else if (py.first == superfinal_ && px.first != superfinal_) {
239 return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
241 return less_(wy, wx);
// Weight associated with `state`: distance_[state] when the index is in range
// and Weight::Zero() when it is not; the superfinal state is tested first.
246 Weight PWeight(StateId state) const {
247 return (state == superfinal_)
249 : (state < distance_.size()) ? distance_[state] : Weight::Zero();
// Non-owning references to the caller's vectors.
252 const std::vector<std::pair<StateId, Weight>> &pairs_;
253 const std::vector<Weight> &distance_;
254 const StateId superfinal_;
256 NaturalLess<Weight> less_;
259 // N-Shortest-path algorithm: implements the core n-shortest path algorithm.
260 // The output is built reversed. See below for versions with more options and
// non-reversed output.
263 // The output mutable FST contains the REVERSE of n-shortest paths in the input
264 // FST; distance must contain the shortest distance from each state to a final
265 // state in the input FST; delta is the convergence delta.
267 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
268 // semiring order. The single path that can be read from the ith of at most n
269 // transitions leaving the initial state of the output FST is the ith
270 // shortest path. Disregarding the initial state and initial transitions, the
271 // n-shortest paths, in fact, form a tree rooted at the single final state.
273 // The weights need to be left and right distributive (kSemiring) and have the
274 // path (kPath) property.
276 // Arc weights must satisfy the property that the sum of the weights of one or
277 // more paths from some state S to T is never Zero(). In particular, arc weights
// are never Zero().
280 // For more information, see:
282 // Mohri, M, and Riley, M. 2002. An efficient algorithm for the n-best-strings
283 // problem. In Proc. ICSLP.
285 // The algorithm relies on the shortest-distance algorithm. There are some
286 // issues with the pseudo-code as written in the paper (viz., line 11).
288 // IMPLEMENTATION NOTE: The input FST can be a delayed FST and at any state
289 // in its expansion the values of distance vector need only be defined at that
290 // time for the states that are known to exist.
// Parameters:
//   ifst:             input FST over RevArc.
//   ofst:             output FST over Arc; cleared before building.
//   distance:         per-state distances read during the search; indices past
//                     its end are treated as Weight::Zero().
//   nshortest:        number of paths requested; nothing is done when <= 0.
//   delta:            comparison delta.
//   weight_threshold: pruning weight threshold, applied relative to the
//                     distance of the input's start state.
//   state_threshold:  pruning threshold on the number of output states;
//                     kNoStateId disables it.
291 template <class Arc, class RevArc>
292 void NShortestPath(const Fst<RevArc> &ifst, MutableFst<Arc> *ofst,
293 const std::vector<typename Arc::Weight> &distance,
294 int32 nshortest, float delta = kDelta,
295 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
296 typename Arc::StateId state_threshold = kNoStateId) {
297 using StateId = typename Arc::StateId;
298 using Weight = typename Arc::Weight;
299 using Pair = std::pair<StateId, Weight>;
// Compile-time requirements on the weight type.
300 static_assert((Weight::Properties() & kPath) == kPath,
301 "Weight must have path property.");
302 static_assert((Weight::Properties() & kSemiring) == kSemiring,
303 "Weight must be distributive.");
304 if (nshortest <= 0) return;
305 ofst->DeleteStates();
306 ofst->SetInputSymbols(ifst.InputSymbols());
307 ofst->SetOutputSymbols(ifst.OutputSymbols());
308 // Each state in ofst corresponds to a path with weight w from the initial
309 // state of ifst to a state s in ifst, that can be characterized by a pair
310 // (s, w). The vector pairs maps each state in ofst to the corresponding
311 // pair (s, w).
312 std::vector<Pair> pairs;
313 // The superfinal state is denoted by kNoStateId. The distance from the
314 // superfinal state to the final state is semiring One, so
315 // `distance[kNoStateId]` is not needed.
316 const ShortestPathCompare<StateId, Weight> compare(pairs, distance,
318 const NaturalLess<Weight> less;
// Degenerate inputs: no start state, no distance entry for it, a Zero()
// distance at the start, a weight threshold below One(), or a zero state
// threshold.
319 if (ifst.Start() == kNoStateId || distance.size() <= ifst.Start() ||
320 distance[ifst.Start()] == Weight::Zero() ||
321 less(weight_threshold, Weight::One()) || state_threshold == 0) {
322 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
// Output state 0 is the start state; the next state added is final with weight
// One() and is paired with the input's start state.
325 ofst->SetStart(ofst->AddState());
326 const auto final_state = ofst->AddState();
327 ofst->SetFinal(final_state, Weight::One());
328 while (pairs.size() <= final_state) {
329 pairs.push_back(std::make_pair(kNoStateId, Weight::Zero()));
331 pairs[final_state] = std::make_pair(ifst.Start(), Weight::One());
// `heap` holds output state ids ordered by `compare`.
332 std::vector<StateId> heap;
333 heap.push_back(final_state);
// Weight bound used for pruning below.
334 const auto limit = Times(distance[ifst.Start()], weight_threshold);
335 // r[s + 1], s state in fst, is the number of states in ofst whose
336 // corresponding pair contains s, i.e., it is the number of paths computed so
337 // far to s. Valid for s == kNoStateId (the superfinal state).
339 while (!heap.empty()) {
// Moves the top heap entry to the back of the vector and reads it from there.
340 std::pop_heap(heap.begin(), heap.end(), compare);
341 const auto state = heap.back();
342 const auto p = pairs[state];
// Selects the distance term for p.first; out-of-range states count as Zero().
345 (p.first == kNoStateId)
347 : (p.first < distance.size()) ? distance[p.first] : Weight::Zero();
// Pruning test: the estimate exceeds the weight limit, or the output already
// has at least state_threshold states.
348 if (less(limit, Times(d, p.second)) ||
349 (state_threshold != kNoStateId &&
350 ofst->NumStates() >= state_threshold)) {
// Grows the counter vector so that index p.first + 1 is valid.
353 while (r.size() <= p.first + 1) r.push_back(0);
// A popped superfinal entry is a complete path: it is linked from the output
// start state with an epsilon (label 0) arc of weight One().
355 if (p.first == kNoStateId) {
356 ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
// Stops once nshortest complete paths exist; skips states already reached more
// than nshortest times; the superfinal state is never expanded.
358 if ((p.first == kNoStateId) && (r[p.first + 1] == nshortest)) break;
359 if (r[p.first + 1] > nshortest) continue;
360 if (p.first == kNoStateId) continue;
// Expands p.first: each outgoing arc yields a new output state whose arc points
// back to `state`, which is why the result is built reversed.
361 for (ArcIterator<Fst<RevArc>> aiter(ifst, p.first); !aiter.Done();
363 const auto &rarc = aiter.Value();
364 Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
365 const auto weight = Times(p.second, arc.weight);
366 const auto next = ofst->AddState();
367 pairs.push_back(std::make_pair(arc.nextstate, weight));
368 arc.nextstate = state;
369 ofst->AddArc(next, arc);
370 heap.push_back(next);
371 std::push_heap(heap.begin(), heap.end(), compare);
// A final p.first additionally yields a superfinal entry, connected by an
// epsilon arc that carries the final weight.
373 const auto final_weight = ifst.Final(p.first).Reverse();
374 if (final_weight != Weight::Zero()) {
375 const auto weight = Times(p.second, final_weight);
376 const auto next = ofst->AddState();
377 pairs.push_back(std::make_pair(kNoStateId, weight));
378 ofst->AddArc(next, Arc(0, 0, final_weight, state));
379 heap.push_back(next);
380 std::push_heap(heap.begin(), heap.end(), compare);
// Propagates an input error to the output.
384 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
// Derives shortest-path property bits from the output's current properties.
// NOTE(review): the call that consumes this argument is not visible here.
386 ShortestPathProperties(ofst->Properties(kFstProperties, false)),
390 } // namespace internal
392 // N-Shortest-path algorithm: this version allows finer control via the options
393 // argument. See below for a simpler interface. The output mutable FST contains
394 // the n-shortest paths in the input FST; the distance argument is used to
395 // return the shortest distances from the source state to each state in the
396 // input FST, and the options struct is used to specify the number of paths to
397 // return, whether they need to have distinct input strings, the queue
398 // discipline, the arc filter and the convergence delta.
400 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
401 // semiring order. The single path that can be read from the ith of at most n
402 // transitions leaving the initial state of the output FST is the ith shortest
// path.
404 // Disregarding the initial state and initial transitions, the n-shortest paths,
405 // in fact, form a tree rooted at the single final state.
407 // The weights need to be right distributive and have the path (kPath) property.
408 // They need to be left distributive as well for nshortest > 1.
410 // For more information, see:
412 // Mohri, M, and Riley, M. 2002. An efficient algorithm for the n-best-strings
413 // problem. In Proc. ICSLP.
415 // The algorithm relies on the shortest-distance algorithm. There are some
416 // issues with the pseudo-code as written in the paper (viz., line 11).
// This overload is enabled only when the weight type has both the kPath and
// kSemiring properties.
// Parameters:
//   ifst:     input FST.
//   ofst:     output FST receiving the result.
//   distance: in/out distance vector; filled here unless opts.has_distance.
//   opts:     number of paths, queue, delta, thresholds and related flags.
417 template <class Arc, class Queue, class ArcFilter,
418 typename std::enable_if<
419 (Arc::Weight::Properties() & (kPath | kSemiring)) ==
420 (kPath | kSemiring)>::type * = nullptr>
421 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
422 std::vector<typename Arc::Weight> *distance,
423 const ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
424 using StateId = typename Arc::StateId;
425 using Weight = typename Arc::Weight;
426 using RevArc = ReverseArc<Arc>;
// Single-path case: run the single-source search, then backtrace from the
// state it reports into ofst; the kError assignment below is the failure path.
427 if (opts.nshortest == 1) {
428 std::vector<std::pair<StateId, size_t>> parent;
430 if (internal::SingleShortestPath(ifst, distance, opts, &f_parent,
432 internal::SingleShortestPathBacktrace(ifst, ofst, parent, f_parent);
434 ofst->SetProperties(kError, kError);
438 if (opts.nshortest <= 0) return;
// Computes the distances unless the caller supplied them. A single non-member
// entry in the result is treated as an error.
439 if (!opts.has_distance) {
440 ShortestDistance(ifst, distance, opts);
441 if (distance->size() == 1 && !(*distance)[0].Member()) {
442 ofst->SetProperties(kError, kError);
446 // Algorithm works on the reverse of 'fst'; 'distance' is the distance to the
447 // final state in 'rfst', 'ofst' is built as the reverse of the tree of
448 // n-shortest path in 'rfst'.
449 VectorFst<RevArc> rfst;
450 Reverse(ifst, &rfst);
// Accumulates the distance for state 0 of rfst from its outgoing arcs.
// NOTE(review): the `- 1` suggests rfst state ids are the ifst ids shifted by
// one, with state 0 added by Reverse(); confirm against Reverse().
451 auto d = Weight::Zero();
452 for (ArcIterator<VectorFst<RevArc>> aiter(rfst, 0); !aiter.Done();
454 const auto &arc = aiter.Value();
455 const auto state = arc.nextstate - 1;
456 if (state < distance->size()) {
457 d = Plus(d, Times(arc.weight.Reverse(), (*distance)[state]));
460 // TODO(kbg): Avoid this expensive vector operation.
// Prepends d so that `distance` gains an entry at index 0.
461 distance->insert(distance->begin(), d);
// Variant that searches rfst directly.
463 internal::NShortestPath(rfst, ofst, *distance, opts.nshortest, opts.delta,
464 opts.weight_threshold, opts.state_threshold);
// Variant that searches a determinized view of rfst, with distances for the
// determinized states collected in ddistance.
466 std::vector<Weight> ddistance;
467 DeterminizeFstOptions<RevArc> dopts(opts.delta);
468 DeterminizeFst<RevArc> dfst(rfst, distance, &ddistance, dopts);
469 internal::NShortestPath(dfst, ofst, ddistance, opts.nshortest, opts.delta,
470 opts.weight_threshold, opts.state_threshold);
472 // TODO(kbg): Avoid this expensive vector operation.
// Removes the entry that was prepended above.
473 distance->erase(distance->begin());
// Fallback overload selected when the weight type lacks the kPath or kSemiring
// property: logs an error naming the weight type and marks ofst with kError.
// The input FST, distance vector and options are ignored.
476 template <class Arc, class Queue, class ArcFilter,
477 typename std::enable_if<
478 (Arc::Weight::Properties() & (kPath | kSemiring)) !=
479 (kPath | kSemiring)>::type * = nullptr>
480 void ShortestPath(const Fst<Arc> &, MutableFst<Arc> *ofst,
481 std::vector<typename Arc::Weight> *,
482 const ShortestPathOptions<Arc, Queue, ArcFilter> &) {
483 FSTERROR() << "ShortestPath: Weight needs to have the "
484 << "path property and be distributive: " << Arc::Weight::Type();
485 ofst->SetProperties(kError, kError);
488 // Shortest-path algorithm: simplified interface. See above for a version that
489 // allows finer control. The output mutable FST contains the n-shortest paths
490 // in the input FST. The queue discipline is automatically selected. When unique
491 // is true, only paths with distinct input label sequences are returned.
493 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
494 // semiring order. The single path that can be read from the ith of at most n
495 // transitions leaving the initial state of the output FST is the ith best path.
496 // The weights need to be right distributive and have the path (kPath) property.
// Parameters:
//   ifst:             input FST.
//   ofst:             output FST receiving the result.
//   nshortest:        number of paths requested (default 1).
//   unique:           forwarded to the options' `unique` field.
//   first_path:       forwarded to the options' `first_path` field.
//   weight_threshold: pruning weight threshold (default Weight::Zero()).
//   state_threshold:  pruning state threshold (default kNoStateId).
498 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
499 int32 nshortest = 1, bool unique = false,
500 bool first_path = false,
501 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
502 typename Arc::StateId state_threshold = kNoStateId) {
503 using StateId = typename Arc::StateId;
// The distance vector is local, so the computed distances are not returned.
504 std::vector<typename Arc::Weight> distance;
505 AnyArcFilter<Arc> arc_filter;
// The queue is constructed with a pointer to `distance`.
506 AutoQueue<StateId> state_queue(ifst, &distance, arc_filter);
// has_distance is passed as false and delta as kDelta.
507 const ShortestPathOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>> opts(
508 &state_queue, arc_filter, nshortest, unique, false, kDelta, first_path,
509 weight_threshold, state_threshold);
510 ShortestPath(ifst, ofst, &distance, opts);
515 #endif // FST_SHORTEST_PATH_H_