Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / shortest-path.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions to find shortest paths in an FST.
5
6 #ifndef FST_SHORTEST_PATH_H_
7 #define FST_SHORTEST_PATH_H_
8
9 #include <functional>
10 #include <type_traits>
11 #include <utility>
12 #include <vector>
13
14 #include <fst/log.h>
15
16 #include <fst/cache.h>
17 #include <fst/determinize.h>
18 #include <fst/queue.h>
19 #include <fst/shortest-distance.h>
20 #include <fst/test-properties.h>
21
22
23 namespace fst {
24
25 template <class Arc, class Queue, class ArcFilter>
26 struct ShortestPathOptions
27     : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
28   using StateId = typename Arc::StateId;
29   using Weight = typename Arc::Weight;
30
31   int32 nshortest;    // Returns n-shortest paths.
32   bool unique;        // Only returns paths with distinct input strings.
33   bool has_distance;  // Distance vector already contains the
34                       // shortest distance from the initial state.
35   bool first_path;    // Single shortest path stops after finding the first
36                       // path to a final state; that path is the shortest path
37                       // only when using the ShortestFirstQueue and
38                       // only when all the weights in the FST are between
39                       // One() and Zero() according to NaturalLess.
40   Weight weight_threshold;  // Pruning weight threshold.
41   StateId state_threshold;  // Pruning state threshold.
42
43   ShortestPathOptions(Queue *queue, ArcFilter filter, int32 nshortest = 1,
44                       bool unique = false, bool has_distance = false,
45                       float delta = kShortestDelta, bool first_path = false,
46                       Weight weight_threshold = Weight::Zero(),
47                       StateId state_threshold = kNoStateId)
48       : ShortestDistanceOptions<Arc, Queue, ArcFilter>(queue, filter,
49                                                        kNoStateId, delta),
50         nshortest(nshortest),
51         unique(unique),
52         has_distance(has_distance),
53         first_path(first_path),
54         weight_threshold(std::move(weight_threshold)),
55         state_threshold(state_threshold) {}
56 };
57
58 namespace internal {
59
60 constexpr size_t kNoArc = -1;
61
62 // Helper function for SingleShortestPath building the shortest path as a left-
63 // to-right machine backwards from the best final state. It takes the input
64 // FST passed to SingleShortestPath and the parent vector and f_parent returned
65 // by that function, and builds the result into the provided output mutable FS
66 // This is not normally called by users; see ShortestPath instead.
67 template <class Arc>
68 void SingleShortestPathBacktrace(
69     const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
70     const std::vector<std::pair<typename Arc::StateId, size_t>> &parent,
71     typename Arc::StateId f_parent) {
72   using StateId = typename Arc::StateId;
73   ofst->DeleteStates();
74   ofst->SetInputSymbols(ifst.InputSymbols());
75   ofst->SetOutputSymbols(ifst.OutputSymbols());
76   StateId s_p = kNoStateId;
77   StateId d_p = kNoStateId;
78   for (StateId state = f_parent, d = kNoStateId; state != kNoStateId;
79        d = state, state = parent[state].first) {
80     d_p = s_p;
81     s_p = ofst->AddState();
82     if (d == kNoStateId) {
83       ofst->SetFinal(s_p, ifst.Final(f_parent));
84     } else {
85       ArcIterator<Fst<Arc>> aiter(ifst, state);
86       aiter.Seek(parent[d].second);
87       auto arc = aiter.Value();
88       arc.nextstate = d_p;
89       ofst->AddArc(s_p, arc);
90     }
91   }
92   ofst->SetStart(s_p);
93   if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
94   ofst->SetProperties(
95       ShortestPathProperties(ofst->Properties(kFstProperties, false), true),
96       kFstProperties);
97 }
98
99 // Helper function for SingleShortestPath building a tree of shortest paths to
100 // every final state in the input FST. It takes the input FST and parent values
101 // computed by SingleShortestPath and builds into the output mutable FST the
102 // subtree of ifst that consists only of the best paths to all final states.
103 // This is not normally called by users; see ShortestPath instead.
104 template <class Arc>
105 void SingleShortestTree(
106     const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
107     const std::vector<std::pair<typename Arc::StateId, size_t>> &parent) {
108   ofst->DeleteStates();
109   ofst->SetInputSymbols(ifst.InputSymbols());
110   ofst->SetOutputSymbols(ifst.OutputSymbols());
111   ofst->SetStart(ifst.Start());
112   for (StateIterator<Fst<Arc>> siter(ifst); !siter.Done(); siter.Next()) {
113     ofst->AddState();
114     ofst->SetFinal(siter.Value(), ifst.Final(siter.Value()));
115   }
116   for (const auto &pair : parent) {
117     if (pair.first != kNoStateId && pair.second != kNoArc) {
118       ArcIterator<Fst<Arc>> aiter(ifst, pair.first);
119       aiter.Seek(pair.second);
120       ofst->AddArc(pair.first, aiter.Value());
121     }
122   }
123   if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
124   ofst->SetProperties(
125       ShortestPathProperties(ofst->Properties(kFstProperties, false), true),
126       kFstProperties);
127 }
128
129 // Shortest-path algorithm. It builds the output mutable FST so that it contains
130 // the shortest path in the input FST; distance returns the shortest distances
131 // from the source state to each state in the input FST, and the options struct
132 // is
133 // used to specify options such as the queue discipline, the arc filter and
134 // delta. The super_final option is an output parameter indicating the final
135 // state, and the parent argument is used for the storage of the backtrace path
136 // for each state 1 to n, (i.e., the best previous state and the arc that
137 // transition to state n.) The shortest path is the lowest weight path w.r.t.
138 // the natural semiring order. The weights need to be right distributive and
139 // have the path (kPath) property. False is returned if an error is encountered.
140 //
141 // This is not normally called by users; see ShortestPath instead (with n = 1).
142 template <class Arc, class Queue, class ArcFilter>
143 bool SingleShortestPath(
144     const Fst<Arc> &ifst, std::vector<typename Arc::Weight> *distance,
145     const ShortestPathOptions<Arc, Queue, ArcFilter> &opts,
146     typename Arc::StateId *f_parent,
147     std::vector<std::pair<typename Arc::StateId, size_t>> *parent) {
148   using StateId = typename Arc::StateId;
149   using Weight = typename Arc::Weight;
150   static_assert((Weight::Properties() & kPath) == kPath,
151                 "Weight must have path property.");
152   static_assert((Weight::Properties() & kRightSemiring) == kRightSemiring,
153                 "Weight must be right distributive.");
154   parent->clear();
155   *f_parent = kNoStateId;
156   if (ifst.Start() == kNoStateId) return true;
157   std::vector<bool> enqueued;
158   auto state_queue = opts.state_queue;
159   const auto source = (opts.source == kNoStateId) ? ifst.Start() : opts.source;
160   bool final_seen = false;
161   auto f_distance = Weight::Zero();
162   distance->clear();
163   state_queue->Clear();
164   while (distance->size() < source) {
165     distance->push_back(Weight::Zero());
166     enqueued.push_back(false);
167     parent->push_back(std::make_pair(kNoStateId, kNoArc));
168   }
169   distance->push_back(Weight::One());
170   parent->push_back(std::make_pair(kNoStateId, kNoArc));
171   state_queue->Enqueue(source);
172   enqueued.push_back(true);
173   while (!state_queue->Empty()) {
174     const auto s = state_queue->Head();
175     state_queue->Dequeue();
176     enqueued[s] = false;
177     const auto sd = (*distance)[s];
178     // If we are using a shortest queue, no other path is going to be shorter
179     // than f_distance at this point.
180     if (opts.first_path && final_seen && f_distance == Plus(f_distance, sd)) {
181       break;
182     }
183     if (ifst.Final(s) != Weight::Zero()) {
184       const auto plus = Plus(f_distance, Times(sd, ifst.Final(s)));
185       if (f_distance != plus) {
186         f_distance = plus;
187         *f_parent = s;
188       }
189       if (!f_distance.Member()) return false;
190       final_seen = true;
191     }
192     for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
193       const auto &arc = aiter.Value();
194       while (distance->size() <= arc.nextstate) {
195         distance->push_back(Weight::Zero());
196         enqueued.push_back(false);
197         parent->push_back(std::make_pair(kNoStateId, kNoArc));
198       }
199       auto &nd = (*distance)[arc.nextstate];
200       const auto weight = Times(sd, arc.weight);
201       if (nd != Plus(nd, weight)) {
202         nd = Plus(nd, weight);
203         if (!nd.Member()) return false;
204         (*parent)[arc.nextstate] = std::make_pair(s, aiter.Position());
205         if (!enqueued[arc.nextstate]) {
206           state_queue->Enqueue(arc.nextstate);
207           enqueued[arc.nextstate] = true;
208         } else {
209           state_queue->Update(arc.nextstate);
210         }
211       }
212     }
213   }
214   return true;
215 }
216
217 template <class StateId, class Weight>
218 class ShortestPathCompare {
219  public:
220   ShortestPathCompare(const std::vector<std::pair<StateId, Weight>> &pairs,
221                       const std::vector<Weight> &distance, StateId superfinal,
222                       float delta)
223       : pairs_(pairs),
224         distance_(distance),
225         superfinal_(superfinal),
226         delta_(delta) {}
227
228   bool operator()(const StateId x, const StateId y) const {
229     const auto &px = pairs_[x];
230     const auto &py = pairs_[y];
231     const auto wx = Times(PWeight(px.first), px.second);
232     const auto wy = Times(PWeight(py.first), py.second);
233     // Penalize complete paths to ensure correct results with inexact weights.
234     // This forms a strict weak order so long as ApproxEqual(a, b) =>
235     // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
236     if (px.first == superfinal_ && py.first != superfinal_) {
237       return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
238     } else if (py.first == superfinal_ && px.first != superfinal_) {
239       return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
240     } else {
241       return less_(wy, wx);
242     }
243   }
244
245  private:
246   Weight PWeight(StateId state) const {
247     return (state == superfinal_)
248                ? Weight::One()
249                : (state < distance_.size()) ? distance_[state] : Weight::Zero();
250   }
251
252   const std::vector<std::pair<StateId, Weight>> &pairs_;
253   const std::vector<Weight> &distance_;
254   const StateId superfinal_;
255   const float delta_;
256   NaturalLess<Weight> less_;
257 };
258
259 // N-Shortest-path algorithm: implements the core n-shortest path algorithm.
260 // The output is built reversed. See below for versions with more options and
261 // *not reversed*.
262 //
263 // The output mutable FST contains the REVERSE of n'shortest paths in the input
264 // FST; distance must contain the shortest distance from each state to a final
265 // state in the input FST; delta is the convergence delta.
266 //
267 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
268 // semiring order. The single path that can be read from the ith of at most n
269 // transitions leaving the initial state of the the input FST is the ith
270 // shortest path. Disregarding the initial state and initial transitions, the
271 // n-shortest paths, in fact, form a tree rooted at the single final state.
272 //
273 // The weights need to be left and right distributive (kSemiring) and have the
274 // path (kPath) property.
275 //
276 // Arc weights must satisfy the property that the sum of the weights of one or
277 // more paths from some state S to T is never Zero(). In particular, arc weights
278 // are never Zero().
279 //
280 // For more information, see:
281 //
282 // Mohri, M, and Riley, M. 2002. An efficient algorithm for the n-best-strings
283 // problem. In Proc. ICSLP.
284 //
285 // The algorithm relies on the shortest-distance algorithm. There are some
286 // issues with the pseudo-code as written in the paper (viz., line 11).
287 //
288 // IMPLEMENTATION NOTE: The input FST can be a delayed FST and and at any state
289 // in its expansion the values of distance vector need only be defined at that
290 // time for the states that are known to exist.
291 template <class Arc, class RevArc>
292 void NShortestPath(const Fst<RevArc> &ifst, MutableFst<Arc> *ofst,
293                    const std::vector<typename Arc::Weight> &distance,
294                    int32 nshortest, float delta = kShortestDelta,
295                    typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
296                    typename Arc::StateId state_threshold = kNoStateId) {
297   using StateId = typename Arc::StateId;
298   using Weight = typename Arc::Weight;
299   using Pair = std::pair<StateId, Weight>;
300   static_assert((Weight::Properties() & kPath) == kPath,
301                 "Weight must have path property.");
302   static_assert((Weight::Properties() & kSemiring) == kSemiring,
303                 "Weight must be distributive.");
304   if (nshortest <= 0) return;
305   ofst->DeleteStates();
306   ofst->SetInputSymbols(ifst.InputSymbols());
307   ofst->SetOutputSymbols(ifst.OutputSymbols());
308   // Each state in ofst corresponds to a path with weight w from the initial
309   // state of ifst to a state s in ifst, that can be characterized by a pair
310   // (s, w). The vector pairs maps each state in ofst to the corresponding
311   // pair maps states in ofst to the corresponding pair (s, w).
312   std::vector<Pair> pairs;
313   // The supefinal state is denoted by kNoStateId. The distance from the
314   // superfinal state to the final state is semiring One, so
315   // `distance[kNoStateId]` is not needed.
316   const ShortestPathCompare<StateId, Weight> compare(pairs, distance,
317                                                      kNoStateId, delta);
318   const NaturalLess<Weight> less;
319   if (ifst.Start() == kNoStateId || distance.size() <= ifst.Start() ||
320       distance[ifst.Start()] == Weight::Zero() ||
321       less(weight_threshold, Weight::One()) || state_threshold == 0) {
322     if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
323     return;
324   }
325   ofst->SetStart(ofst->AddState());
326   const auto final_state = ofst->AddState();
327   ofst->SetFinal(final_state, Weight::One());
328   while (pairs.size() <= final_state) {
329     pairs.push_back(std::make_pair(kNoStateId, Weight::Zero()));
330   }
331   pairs[final_state] = std::make_pair(ifst.Start(), Weight::One());
332   std::vector<StateId> heap;
333   heap.push_back(final_state);
334   const auto limit = Times(distance[ifst.Start()], weight_threshold);
335   // r[s + 1], s state in fst, is the number of states in ofst which
336   // corresponding pair contains s, i.e., it is number of paths computed so far
337   // to s. Valid for s == kNoStateId (the superfinal state).
338   std::vector<int> r;
339   while (!heap.empty()) {
340     std::pop_heap(heap.begin(), heap.end(), compare);
341     const auto state = heap.back();
342     const auto p = pairs[state];
343     heap.pop_back();
344     const auto d =
345         (p.first == kNoStateId)
346             ? Weight::One()
347             : (p.first < distance.size()) ? distance[p.first] : Weight::Zero();
348     if (less(limit, Times(d, p.second)) ||
349         (state_threshold != kNoStateId &&
350          ofst->NumStates() >= state_threshold)) {
351       continue;
352     }
353     while (r.size() <= p.first + 1) r.push_back(0);
354     ++r[p.first + 1];
355     if (p.first == kNoStateId) {
356       ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
357     }
358     if ((p.first == kNoStateId) && (r[p.first + 1] == nshortest)) break;
359     if (r[p.first + 1] > nshortest) continue;
360     if (p.first == kNoStateId) continue;
361     for (ArcIterator<Fst<RevArc>> aiter(ifst, p.first); !aiter.Done();
362          aiter.Next()) {
363       const auto &rarc = aiter.Value();
364       Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
365       const auto weight = Times(p.second, arc.weight);
366       const auto next = ofst->AddState();
367       pairs.push_back(std::make_pair(arc.nextstate, weight));
368       arc.nextstate = state;
369       ofst->AddArc(next, arc);
370       heap.push_back(next);
371       std::push_heap(heap.begin(), heap.end(), compare);
372     }
373     const auto final_weight = ifst.Final(p.first).Reverse();
374     if (final_weight != Weight::Zero()) {
375       const auto weight = Times(p.second, final_weight);
376       const auto next = ofst->AddState();
377       pairs.push_back(std::make_pair(kNoStateId, weight));
378       ofst->AddArc(next, Arc(0, 0, final_weight, state));
379       heap.push_back(next);
380       std::push_heap(heap.begin(), heap.end(), compare);
381     }
382   }
383   Connect(ofst);
384   if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
385   ofst->SetProperties(
386       ShortestPathProperties(ofst->Properties(kFstProperties, false)),
387       kFstProperties);
388 }
389
390 }  // namespace internal
391
392 // N-Shortest-path algorithm: this version allows finer control via the options
393 // argument. See below for a simpler interface. The output mutable FST contains
394 // the n-shortest paths in the input FST; the distance argument is used to
395 // return the shortest distances from the source state to each state in the
396 // input FST, and the options struct is used to specify the number of paths to
397 // return, whether they need to have distinct input strings, the queue
398 // discipline, the arc filter and the convergence delta.
399 //
400 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
401 // semiring order. The single path that can be read from the ith of at most n
402 // transitions leaving the initial state of the output FST is the ith shortest
403 // path.
404 // Disregarding the initial state and initial transitions, The n-shortest paths,
405 // in fact, form a tree rooted at the single final state.
406 //
407 // The weights need to be right distributive and have the path (kPath) property.
408 // They need to be left distributive as well for nshortest > 1.
409 //
410 // For more information, see:
411 //
412 // Mohri, M, and Riley, M. 2002. An efficient algorithm for the n-best-strings
413 // problem. In Proc. ICSLP.
414 //
415 // The algorithm relies on the shortest-distance algorithm. There are some
416 // issues with the pseudo-code as written in the paper (viz., line 11).
417 template <class Arc, class Queue, class ArcFilter,
418           typename std::enable_if<
419               (Arc::Weight::Properties() & (kPath | kSemiring)) ==
420               (kPath | kSemiring)>::type * = nullptr>
421 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
422                   std::vector<typename Arc::Weight> *distance,
423                   const ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
424   using StateId = typename Arc::StateId;
425   using Weight = typename Arc::Weight;
426   using RevArc = ReverseArc<Arc>;
427   if (opts.nshortest == 1) {
428     std::vector<std::pair<StateId, size_t>> parent;
429     StateId f_parent;
430     if (internal::SingleShortestPath(ifst, distance, opts, &f_parent,
431                                      &parent)) {
432       internal::SingleShortestPathBacktrace(ifst, ofst, parent, f_parent);
433     } else {
434       ofst->SetProperties(kError, kError);
435     }
436     return;
437   }
438   if (opts.nshortest <= 0) return;
439   if (!opts.has_distance) {
440     ShortestDistance(ifst, distance, opts);
441     if (distance->size() == 1 && !(*distance)[0].Member()) {
442       ofst->SetProperties(kError, kError);
443       return;
444     }
445   }
446   // Algorithm works on the reverse of 'fst'; 'distance' is the distance to the
447   // final state in 'rfst', 'ofst' is built as the reverse of the tree of
448   // n-shortest path in 'rfst'.
449   VectorFst<RevArc> rfst;
450   Reverse(ifst, &rfst);
451   auto d = Weight::Zero();
452   for (ArcIterator<VectorFst<RevArc>> aiter(rfst, 0); !aiter.Done();
453        aiter.Next()) {
454     const auto &arc = aiter.Value();
455     const auto state = arc.nextstate - 1;
456     if (state < distance->size()) {
457       d = Plus(d, Times(arc.weight.Reverse(), (*distance)[state]));
458     }
459   }
460   // TODO(kbg): Avoid this expensive vector operation.
461   distance->insert(distance->begin(), d);
462   if (!opts.unique) {
463     internal::NShortestPath(rfst, ofst, *distance, opts.nshortest, opts.delta,
464                             opts.weight_threshold, opts.state_threshold);
465   } else {
466     std::vector<Weight> ddistance;
467     DeterminizeFstOptions<RevArc> dopts(opts.delta);
468     DeterminizeFst<RevArc> dfst(rfst, distance, &ddistance, dopts);
469     internal::NShortestPath(dfst, ofst, ddistance, opts.nshortest, opts.delta,
470                             opts.weight_threshold, opts.state_threshold);
471   }
472   // TODO(kbg): Avoid this expensive vector operation.
473   distance->erase(distance->begin());
474 }
475
476 template <class Arc, class Queue, class ArcFilter,
477           typename std::enable_if<
478               (Arc::Weight::Properties() & (kPath | kSemiring)) !=
479               (kPath | kSemiring)>::type * = nullptr>
480 void ShortestPath(const Fst<Arc> &, MutableFst<Arc> *ofst,
481                   std::vector<typename Arc::Weight> *,
482                   const ShortestPathOptions<Arc, Queue, ArcFilter> &) {
483   FSTERROR() << "ShortestPath: Weight needs to have the "
484              << "path property and be distributive: " << Arc::Weight::Type();
485   ofst->SetProperties(kError, kError);
486 }
487
488 // Shortest-path algorithm: simplified interface. See above for a version that
489 // allows finer control. The output mutable FST contains the n-shortest paths
490 // in the input FST. The queue discipline is automatically selected. When unique
491 // is true, only paths with distinct input label sequences are returned.
492 //
493 // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
494 // semiring order. The single path that can be read from the ith of at most n
495 // transitions leaving the initial state of the ouput FST is the ith best path.
496 // The weights need to be right distributive and have the path (kPath) property.
497 template <class Arc>
498 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
499                   int32 nshortest = 1, bool unique = false,
500                   bool first_path = false,
501                   typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
502                   typename Arc::StateId state_threshold = kNoStateId,
503                   float delta = kShortestDelta) {
504   using StateId = typename Arc::StateId;
505   std::vector<typename Arc::Weight> distance;
506   AnyArcFilter<Arc> arc_filter;
507   AutoQueue<StateId> state_queue(ifst, &distance, arc_filter);
508   const ShortestPathOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>> opts(
509       &state_queue, arc_filter, nshortest, unique, false, delta, first_path,
510       weight_threshold, state_threshold);
511   ShortestPath(ifst, ofst, &distance, opts);
512 }
513
514 }  // namespace fst
515
516 #endif  // FST_SHORTEST_PATH_H_