Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / shortest-path.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions to find shortest paths in an FST.
5
6 #ifndef FST_LIB_SHORTEST_PATH_H_
7 #define FST_LIB_SHORTEST_PATH_H_
8
9 #include <functional>
10 #include <utility>
11 #include <vector>
12
13 #include <fst/log.h>
14
15 #include <fst/cache.h>
16 #include <fst/determinize.h>
17 #include <fst/queue.h>
18 #include <fst/shortest-distance.h>
19 #include <fst/test-properties.h>
20
21
22 namespace fst {
23
24 template <class Arc, class Queue, class ArcFilter>
25 struct ShortestPathOptions
26     : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
27   using StateId = typename Arc::StateId;
28   using Weight = typename Arc::Weight;
29
30   int32 nshortest;    // Returns n-shortest paths.
31   bool unique;        // Only returns paths with distinct input strings.
32   bool has_distance;  // Distance vector already contains the
33                       // shortest distance from the initial state.
34   bool first_path;    // Single shortest path stops after finding the first
35                       // path to a final state; that path is the shortest path
36                       // only when using the ShortestFirstQueue and
37                       // only when all the weights in the FST are between
38                       // One() and Zero() according to NaturalLess.
39   Weight weight_threshold;  // Pruning weight threshold.
40   StateId state_threshold;  // Pruning state threshold.
41
42   ShortestPathOptions(Queue *queue, ArcFilter filter, int32 nshortest = 1,
43                       bool unique = false, bool has_distance = false,
44                       float delta = kDelta, bool first_path = false,
45                       Weight weight_threshold = Weight::Zero(),
46                       StateId state_threshold = kNoStateId)
47       : ShortestDistanceOptions<Arc, Queue, ArcFilter>(queue, filter,
48                                                        kNoStateId, delta),
49         nshortest(nshortest),
50         unique(unique),
51         has_distance(has_distance),
52         first_path(first_path),
53         weight_threshold(std::move(weight_threshold)),
54         state_threshold(state_threshold) {}
55 };
56
57 namespace internal {
58
59 constexpr size_t kNoArc = -1;
60
61 // Helper function for SingleShortestPath building the shortest path as a left-
62 // to-right machine backwards from the best final state. It takes the input
63 // FST passed to SingleShortestPath and the parent vector and f_parent returned
64 // by that function, and builds the result into the provided output mutable FS
65 // This is not normally called by users; see ShortestPath instead.
66 template <class Arc>
67 void SingleShortestPathBacktrace(
68     const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
69     const std::vector<std::pair<typename Arc::StateId, size_t>> &parent,
70     typename Arc::StateId f_parent) {
71   using StateId = typename Arc::StateId;
72   ofst->DeleteStates();
73   ofst->SetInputSymbols(ifst.InputSymbols());
74   ofst->SetOutputSymbols(ifst.OutputSymbols());
75   StateId s_p = kNoStateId;
76   StateId d_p = kNoStateId;
77   for (StateId state = f_parent, d = kNoStateId; state != kNoStateId;
78        d = state, state = parent[state].first) {
79     d_p = s_p;
80     s_p = ofst->AddState();
81     if (d == kNoStateId) {
82       ofst->SetFinal(s_p, ifst.Final(f_parent));
83     } else {
84       ArcIterator<Fst<Arc>> aiter(ifst, state);
85       aiter.Seek(parent[d].second);
86       auto arc = aiter.Value();
87       arc.nextstate = d_p;
88       ofst->AddArc(s_p, arc);
89     }
90   }
91   ofst->SetStart(s_p);
92   if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
93   ofst->SetProperties(
94       ShortestPathProperties(ofst->Properties(kFstProperties, false), true),
95       kFstProperties);
96 }
97
98 // Helper function for SingleShortestPath building a tree of shortest paths to
99 // every final state in the input FST. It takes the input FST and parent values
100 // computed by SingleShortestPath and builds into the output mutable FST the
101 // subtree of ifst that consists only of the best paths to all final states.
102 // This is not normally called by users; see ShortestPath instead.
103 template <class Arc>
104 void SingleShortestTree(
105     const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
106     const std::vector<std::pair<typename Arc::StateId, size_t>> &parent) {
107   ofst->DeleteStates();
108   ofst->SetInputSymbols(ifst.InputSymbols());
109   ofst->SetOutputSymbols(ifst.OutputSymbols());
110   ofst->SetStart(ifst.Start());
111   for (StateIterator<Fst<Arc>> siter(ifst); !siter.Done(); siter.Next()) {
112     ofst->AddState();
113     ofst->SetFinal(siter.Value(), ifst.Final(siter.Value()));
114   }
115   for (const auto &pair : parent) {
116     if (pair.first != kNoStateId && pair.second != kNoArc) {
117       ArcIterator<Fst<Arc>> aiter(ifst, pair.first);
118       aiter.Seek(pair.second);
119       ofst->AddArc(pair.first, aiter.Value());
120     }
121   }
122   if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
123   ofst->SetProperties(
124       ShortestPathProperties(ofst->Properties(kFstProperties, false), true),
125       kFstProperties);
126 }
127
128 // Shortest-path algorithm. It builds the output mutable FST so that it contains
129 // the shortest path in the input FST; distance returns the shortest distances
130 // from the source state to each state in the input FST, and the options struct
131 // is
132 // used to specify options such as the queue discipline, the arc filter and
133 // delta. The super_final option is an output parameter indicating the final
134 // state, and the parent argument is used for the storage of the backtrace path
135 // for each state 1 to n, (i.e., the best previous state and the arc that
136 // transition to state n.) The shortest path is the lowest weight path w.r.t.
137 // the natural semiring order. The weights need to be right distributive and
138 // have the path (kPath) property. False is returned if an error is encountered.
139 //
140 // This is not normally called by users; see ShortestPath instead (with n = 1).
141 template <class Arc, class Queue, class ArcFilter>
142 bool SingleShortestPath(
143     const Fst<Arc> &ifst, std::vector<typename Arc::Weight> *distance,
144     const ShortestPathOptions<Arc, Queue, ArcFilter> &opts,
145     typename Arc::StateId *f_parent,
146     std::vector<std::pair<typename Arc::StateId, size_t>> *parent) {
147   using StateId = typename Arc::StateId;
148   using Weight = typename Arc::Weight;
149   parent->clear();
150   *f_parent = kNoStateId;
151   if (ifst.Start() == kNoStateId) return true;
152   std::vector<bool> enqueued;
153   auto state_queue = opts.state_queue;
154   const auto source = (opts.source == kNoStateId) ? ifst.Start() : opts.source;
155   bool final_seen = false;
156   auto f_distance = Weight::Zero();
157   distance->clear();
158   state_queue->Clear();
159   if ((Weight::Properties() & (kPath | kRightSemiring)) !=
160       (kPath | kRightSemiring)) {
161     FSTERROR() << "SingleShortestPath: Weight needs to have the path"
162                << " property and be right distributive: " << Weight::Type();
163     return false;
164   }
165   while (distance->size() < source) {
166     distance->push_back(Weight::Zero());
167     enqueued.push_back(false);
168     parent->push_back(std::make_pair(kNoStateId, kNoArc));
169   }
170   distance->push_back(Weight::One());
171   parent->push_back(std::make_pair(kNoStateId, kNoArc));
172   state_queue->Enqueue(source);
173   enqueued.push_back(true);
174   while (!state_queue->Empty()) {
175     const auto s = state_queue->Head();
176     state_queue->Dequeue();
177     enqueued[s] = false;
178     const auto sd = (*distance)[s];
179     // If we are using a shortest queue, no other path is going to be shorter
180     // than f_distance at this point.
181     if (opts.first_path && final_seen && f_distance == Plus(f_distance, sd)) {
182       break;
183     }
184     if (ifst.Final(s) != Weight::Zero()) {
185       const auto plus = Plus(f_distance, Times(sd, ifst.Final(s)));
186       if (f_distance != plus) {
187         f_distance = plus;
188         *f_parent = s;
189       }
190       if (!f_distance.Member()) return false;
191       final_seen = true;
192     }
193     for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
194       const auto &arc = aiter.Value();
195       while (distance->size() <= arc.nextstate) {
196         distance->push_back(Weight::Zero());
197         enqueued.push_back(false);
198         parent->push_back(std::make_pair(kNoStateId, kNoArc));
199       }
200       auto &nd = (*distance)[arc.nextstate];
201       const auto weight = Times(sd, arc.weight);
202       if (nd != Plus(nd, weight)) {
203         nd = Plus(nd, weight);
204         if (!nd.Member()) return false;
205         (*parent)[arc.nextstate] = std::make_pair(s, aiter.Position());
206         if (!enqueued[arc.nextstate]) {
207           state_queue->Enqueue(arc.nextstate);
208           enqueued[arc.nextstate] = true;
209         } else {
210           state_queue->Update(arc.nextstate);
211         }
212       }
213     }
214   }
215   return true;
216 }
217
218 template <class StateId, class Weight>
219 class ShortestPathCompare {
220  public:
221   ShortestPathCompare(const std::vector<std::pair<StateId, Weight>> &pairs,
222                       const std::vector<Weight> &distance, StateId superfinal,
223                       float delta)
224       : pairs_(pairs),
225         distance_(distance),
226         superfinal_(superfinal),
227         delta_(delta) {}
228
229   bool operator()(const StateId x, const StateId y) const {
230     const auto &px = pairs_[x];
231     const auto &py = pairs_[y];
232     const auto wx = Times(PWeight(px.first), px.second);
233     const auto wy = Times(PWeight(py.first), py.second);
234     // Penalize complete paths to ensure correct results with inexact weights.
235     // This forms a strict weak order so long as ApproxEqual(a, b) =>
236     // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
237     if (px.first == superfinal_ && py.first != superfinal_) {
238       return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
239     } else if (py.first == superfinal_ && px.first != superfinal_) {
240       return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
241     } else {
242       return less_(wy, wx);
243     }
244   }
245
246  private:
247   Weight PWeight(StateId state) const {
248     return (state == superfinal_)
249                ? Weight::One()
250                : (state < distance_.size()) ? distance_[state] : Weight::Zero();
251   }
252
253   const std::vector<std::pair<StateId, Weight>> &pairs_;
254   const std::vector<Weight> &distance_;
255   const StateId superfinal_;
256   const float delta_;
257   NaturalLess<Weight> less_;
258 };
259
260 // N-Shortest-path algorithm: implements the core n-shortest path algorithm.
261 // The output is built reversed. See below for versions with more options and
262 // *not reversed*.
263 //
264 // The output mutable FST contains the REVERSE of n'shortest paths in the input
265 // FST; distance must contain the shortest distance from each state to a final
266 // state in the input FST; delta is the convergence delta.
267 //
268 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
269 // semiring order. The single path that can be read from the ith of at most n
270 // transitions leaving the initial state of the the input FST is the ith
271 // shortest path. Disregarding the initial state and initial transitions, the
272 // n-shortest paths, in fact, form a tree rooted at the single final state.
273 //
274 // The weights need to be left and right distributive (kSemiring) and have the
275 // path (kPath) property.
276 //
277 // Arc weights must satisfy the property that the sum of the weights of one or
278 // more paths from some state S to T is never Zero(). In particular, arc weights
279 // are never Zero().
280 //
281 // For more information, see:
282 //
283 // Mohri, M, and Riley, M. 2002. An efficient algorithm for the n-best-strings
284 // problem. In Proc. ICSLP.
285 //
286 // The algorithm relies on the shortest-distance algorithm. There are some
287 // issues with the pseudo-code as written in the paper (viz., line 11).
288 //
289 // IMPLEMENTATION NOTE: The input FST can be a delayed FST and and at any state
290 // in its expansion the values of distance vector need only be defined at that
291 // time for the states that are known to exist.
292 template <class Arc, class RevArc>
293 void NShortestPath(const Fst<RevArc> &ifst, MutableFst<Arc> *ofst,
294                    const std::vector<typename Arc::Weight> &distance,
295                    int32 nshortest, float delta = kDelta,
296                    typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
297                    typename Arc::StateId state_threshold = kNoStateId) {
298   using StateId = typename Arc::StateId;
299   using Weight = typename Arc::Weight;
300   using Pair = std::pair<StateId, Weight>;
301   if (nshortest <= 0) return;
302   // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
303   // way to "deregister" this operation for non-path semirings so an informative
304   // error message is produced.
305   if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
306     FSTERROR() << "NShortestPath: Weight needs to have the "
307                << "path property and be distributive: " << Weight::Type();
308     ofst->SetProperties(kError, kError);
309     return;
310   }
311   ofst->DeleteStates();
312   ofst->SetInputSymbols(ifst.InputSymbols());
313   ofst->SetOutputSymbols(ifst.OutputSymbols());
314   // Each state in ofst corresponds to a path with weight w from the initial
315   // state of ifst to a state s in ifst, that can be characterized by a pair
316   // (s, w). The vector pairs maps each state in ofst to the corresponding
317   // pair maps states in ofst to the corresponding pair (s, w).
318   std::vector<Pair> pairs;
319   // The supefinal state is denoted by kNoStateId. The distance from the
320   // superfinal state to the final state is semiring One, so
321   // `distance[kNoStateId]` is not needed.
322   const ShortestPathCompare<StateId, Weight> compare(pairs, distance,
323                                                      kNoStateId, delta);
324   const NaturalLess<Weight> less;
325   if (ifst.Start() == kNoStateId || distance.size() <= ifst.Start() ||
326       distance[ifst.Start()] == Weight::Zero() ||
327       less(weight_threshold, Weight::One()) || state_threshold == 0) {
328     if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
329     return;
330   }
331   ofst->SetStart(ofst->AddState());
332   const auto final_state = ofst->AddState();
333   ofst->SetFinal(final_state, Weight::One());
334   while (pairs.size() <= final_state) {
335     pairs.push_back(std::make_pair(kNoStateId, Weight::Zero()));
336   }
337   pairs[final_state] = std::make_pair(ifst.Start(), Weight::One());
338   std::vector<StateId> heap;
339   heap.push_back(final_state);
340   const auto limit = Times(distance[ifst.Start()], weight_threshold);
341   // r[s + 1], s state in fst, is the number of states in ofst which
342   // corresponding pair contains s, i.e., it is number of paths computed so far
343   // to s. Valid for s == kNoStateId (the superfinal state).
344   std::vector<int> r;
345   while (!heap.empty()) {
346     std::pop_heap(heap.begin(), heap.end(), compare);
347     const auto state = heap.back();
348     const auto p = pairs[state];
349     heap.pop_back();
350     const auto d =
351         (p.first == kNoStateId)
352             ? Weight::One()
353             : (p.first < distance.size()) ? distance[p.first] : Weight::Zero();
354     if (less(limit, Times(d, p.second)) ||
355         (state_threshold != kNoStateId &&
356          ofst->NumStates() >= state_threshold)) {
357       continue;
358     }
359     while (r.size() <= p.first + 1) r.push_back(0);
360     ++r[p.first + 1];
361     if (p.first == kNoStateId) {
362       ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
363     }
364     if ((p.first == kNoStateId) && (r[p.first + 1] == nshortest)) break;
365     if (r[p.first + 1] > nshortest) continue;
366     if (p.first == kNoStateId) continue;
367     for (ArcIterator<Fst<RevArc>> aiter(ifst, p.first); !aiter.Done();
368          aiter.Next()) {
369       const auto &rarc = aiter.Value();
370       Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
371       const auto weight = Times(p.second, arc.weight);
372       const auto next = ofst->AddState();
373       pairs.push_back(std::make_pair(arc.nextstate, weight));
374       arc.nextstate = state;
375       ofst->AddArc(next, arc);
376       heap.push_back(next);
377       std::push_heap(heap.begin(), heap.end(), compare);
378     }
379     const auto final_weight = ifst.Final(p.first).Reverse();
380     if (final_weight != Weight::Zero()) {
381       const auto weight = Times(p.second, final_weight);
382       const auto next = ofst->AddState();
383       pairs.push_back(std::make_pair(kNoStateId, weight));
384       ofst->AddArc(next, Arc(0, 0, final_weight, state));
385       heap.push_back(next);
386       std::push_heap(heap.begin(), heap.end(), compare);
387     }
388   }
389   Connect(ofst);
390   if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
391   ofst->SetProperties(
392       ShortestPathProperties(ofst->Properties(kFstProperties, false)),
393       kFstProperties);
394 }
395
396 }  // namespace internal
397
398 // N-Shortest-path algorithm: this version allows finer control via the options
399 // argument. See below for a simpler interface. The output mutable FST contains
400 // the n-shortest paths in the input FST; the distance argument is used to
401 // return the shortest distances from the source state to each state in the
402 // input FST, and the options struct is used to specify the number of paths to
403 // return, whether they need to have distinct input strings, the queue
404 // discipline, the arc filter and the convergence delta.
405 //
406 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
407 // semiring order. The single path that can be read from the ith of at most n
408 // transitions leaving the initial state of the output FST is the ith shortest
409 // path.
410 // Disregarding the initial state and initial transitions, The n-shortest paths,
411 // in fact, form a tree rooted at the single final state.
412 //
413 // The weights need to be right distributive and have the path (kPath) property.
414 // They need to be left distributive as well for nshortest > 1.
415 //
416 // For more information, see:
417 //
418 // Mohri, M, and Riley, M. 2002. An efficient algorithm for the n-best-strings
419 // problem. In Proc. ICSLP.
420 //
421 // The algorithm relies on the shortest-distance algorithm. There are some
422 // issues with the pseudo-code as written in the paper (viz., line 11).
423 template <class Arc, class Queue, class ArcFilter>
424 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
425                   std::vector<typename Arc::Weight> *distance,
426                   const ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
427   using StateId = typename Arc::StateId;
428   using Weight = typename Arc::Weight;
429   using RevArc = ReverseArc<Arc>;
430   if (opts.nshortest == 1) {
431     std::vector<std::pair<StateId, size_t>> parent;
432     StateId f_parent;
433     if (internal::SingleShortestPath(ifst, distance, opts, &f_parent,
434                                      &parent)) {
435       internal::SingleShortestPathBacktrace(ifst, ofst, parent, f_parent);
436     } else {
437       ofst->SetProperties(kError, kError);
438     }
439     return;
440   }
441   if (opts.nshortest <= 0) return;
442   // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
443   // way to "deregister" this operation for non-path semirings so an informative
444   // error message is produced.
445   if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
446     FSTERROR() << "ShortestPath: Weight needs to have the "
447                << "path property and be distributive: " << Weight::Type();
448     ofst->SetProperties(kError, kError);
449     return;
450   }
451   if (!opts.has_distance) {
452     ShortestDistance(ifst, distance, opts);
453     if (distance->size() == 1 && !(*distance)[0].Member()) {
454       ofst->SetProperties(kError, kError);
455       return;
456     }
457   }
458   // Algorithm works on the reverse of 'fst'; 'distance' is the distance to the
459   // final state in 'rfst', 'ofst' is built as the reverse of the tree of
460   // n-shortest path in 'rfst'.
461   VectorFst<RevArc> rfst;
462   Reverse(ifst, &rfst);
463   auto d = Weight::Zero();
464   for (ArcIterator<VectorFst<RevArc>> aiter(rfst, 0); !aiter.Done();
465        aiter.Next()) {
466     const auto &arc = aiter.Value();
467     const auto state = arc.nextstate - 1;
468     if (state < distance->size()) {
469       d = Plus(d, Times(arc.weight.Reverse(), (*distance)[state]));
470     }
471   }
472   // TODO(kbg): Avoid this expensive vector operation.
473   distance->insert(distance->begin(), d);
474   if (!opts.unique) {
475     internal::NShortestPath(rfst, ofst, *distance, opts.nshortest, opts.delta,
476                             opts.weight_threshold, opts.state_threshold);
477   } else {
478     std::vector<Weight> ddistance;
479     DeterminizeFstOptions<RevArc> dopts(opts.delta);
480     DeterminizeFst<RevArc> dfst(rfst, distance, &ddistance, dopts);
481     internal::NShortestPath(dfst, ofst, ddistance, opts.nshortest, opts.delta,
482                             opts.weight_threshold, opts.state_threshold);
483   }
484   // TODO(kbg): Avoid this expensive vector operation.
485   distance->erase(distance->begin());
486 }
487
488 // Shortest-path algorithm: simplified interface. See above for a version that
489 // allows finer control. The output mutable FST contains the n-shortest paths
490 // in the input FST. The queue discipline is automatically selected. When unique
491 // is true, only paths with distinct input label sequences are returned.
492 //
493 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
494 // semiring order. The single path that can be read from the ith of at most n
495 // transitions leaving the initial state of the ouput FST is the ith best path.
496 // The weights need to be right distributive and have the path (kPath) property.
497 template <class Arc>
498 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
499                   int32 nshortest = 1, bool unique = false,
500                   bool first_path = false,
501                   typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
502                   typename Arc::StateId state_threshold = kNoStateId) {
503   using StateId = typename Arc::StateId;
504   std::vector<typename Arc::Weight> distance;
505   AnyArcFilter<Arc> arc_filter;
506   AutoQueue<StateId> state_queue(ifst, &distance, arc_filter);
507   const ShortestPathOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>> opts(
508       &state_queue, arc_filter, nshortest, unique, false, kDelta, first_path,
509       weight_threshold, state_threshold);
510   ShortestPath(ifst, ofst, &distance, opts);
511 }
512
513 }  // namespace fst
514
515 #endif  // FST_LIB_SHORTEST_PATH_H_