Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / shortest-distance.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes to find shortest distance in an FST.
5
6 #ifndef FST_SHORTEST_DISTANCE_H_
7 #define FST_SHORTEST_DISTANCE_H_
8
9 #include <deque>
10 #include <vector>
11
12 #include <fst/log.h>
13
14 #include <fst/arcfilter.h>
15 #include <fst/cache.h>
16 #include <fst/queue.h>
17 #include <fst/reverse.h>
18 #include <fst/test-properties.h>
19
20
21 namespace fst {
22
23 // A representable float for shortest distance and shortest path algorithms.
24 constexpr float kShortestDelta = 1e-6;
25
26 template <class Arc, class Queue, class ArcFilter>
27 struct ShortestDistanceOptions {
28   using StateId = typename Arc::StateId;
29
30   Queue *state_queue;    // Queue discipline used; owned by caller.
31   ArcFilter arc_filter;  // Arc filter (e.g., limit to only epsilon graph).
32   StateId source;        // If kNoStateId, use the FST's initial state.
33   float delta;           // Determines the degree of convergence required
34   bool first_path;       // For a semiring with the path property (o.w.
35                          // undefined), compute the shortest-distances along
36                          // along the first path to a final state found
37                          // by the algorithm. That path is the shortest-path
38                          // only if the FST has a unique final state (or all
39                          // the final states have the same final weight), the
40                          // queue discipline is shortest-first and all the
41                          // weights in the FST are between One() and Zero()
42                          // according to NaturalLess.
43
44   ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
45                           StateId source = kNoStateId,
46                           float delta = kShortestDelta)
47       : state_queue(state_queue),
48         arc_filter(arc_filter),
49         source(source),
50         delta(delta),
51         first_path(false) {}
52 };
53
54 namespace internal {
55
56 // Computation state of the shortest-distance algorithm. Reusable information
57 // is maintained across calls to member function ShortestDistance(source) when
58 // retain is true for improved efficiency when calling multiple times from
59 // different source states (e.g., in epsilon removal). Contrary to the usual
60 // conventions, fst may not be freed before this class. Vector distance
61 // should not be modified by the user between these calls. The Error() method
62 // returns true iff an error was encountered.
63 template <class Arc, class Queue, class ArcFilter>
64 class ShortestDistanceState {
65  public:
66   using StateId = typename Arc::StateId;
67   using Weight = typename Arc::Weight;
68
69   ShortestDistanceState(
70       const Fst<Arc> &fst, std::vector<Weight> *distance,
71       const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain)
72       : fst_(fst),
73         distance_(distance),
74         state_queue_(opts.state_queue),
75         arc_filter_(opts.arc_filter),
76         delta_(opts.delta),
77         first_path_(opts.first_path),
78         retain_(retain),
79         source_id_(0),
80         error_(false) {
81     distance_->clear();
82   }
83
84   void ShortestDistance(StateId source);
85
86   bool Error() const { return error_; }
87
88  private:
89   const Fst<Arc> &fst_;
90   std::vector<Weight> *distance_;
91   Queue *state_queue_;
92   ArcFilter arc_filter_;
93   const float delta_;
94   const bool first_path_;
95   const bool retain_;  // Retain and reuse information across calls.
96
97   std::vector<Adder<Weight>> adder_;   // Sums distance_ accurately.
98   std::vector<Adder<Weight>> radder_;  // Relaxation distance.
99   std::vector<bool> enqueued_;         // Is state enqueued?
100   std::vector<StateId> sources_;       // Source ID for ith state in distance_,
101                                        // (r)adder_, and enqueued_ if retained.
102   StateId source_id_;                  // Unique ID characterizing each call.
103   bool error_;
104 };
105
106 // Compute the shortest distance; if source is kNoStateId, uses the initial
107 // state of the FST.
108 template <class Arc, class Queue, class ArcFilter>
109 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
110     StateId source) {
111   if (fst_.Start() == kNoStateId) {
112     if (fst_.Properties(kError, false)) error_ = true;
113     return;
114   }
115   if (!(Weight::Properties() & kRightSemiring)) {
116     FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
117                << Weight::Type();
118     error_ = true;
119     return;
120   }
121   if (first_path_ && !(Weight::Properties() & kPath)) {
122     FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
123                << "Weight does not have the path property: " << Weight::Type();
124     error_ = true;
125     return;
126   }
127   state_queue_->Clear();
128   if (!retain_) {
129     distance_->clear();
130     adder_.clear();
131     radder_.clear();
132     enqueued_.clear();
133   }
134   if (source == kNoStateId) source = fst_.Start();
135   while (distance_->size() <= source) {
136     distance_->push_back(Weight::Zero());
137     adder_.push_back(Adder<Weight>());
138     radder_.push_back(Adder<Weight>());
139     enqueued_.push_back(false);
140   }
141   if (retain_) {
142     while (sources_.size() <= source) sources_.push_back(kNoStateId);
143     sources_[source] = source_id_;
144   }
145   (*distance_)[source] = Weight::One();
146   adder_[source].Reset(Weight::One());
147   radder_[source].Reset(Weight::One());
148   enqueued_[source] = true;
149   state_queue_->Enqueue(source);
150   while (!state_queue_->Empty()) {
151     const auto state = state_queue_->Head();
152     state_queue_->Dequeue();
153     while (distance_->size() <= state) {
154       distance_->push_back(Weight::Zero());
155       adder_.push_back(Adder<Weight>());
156       radder_.push_back(Adder<Weight>());
157       enqueued_.push_back(false);
158     }
159     if (first_path_ && (fst_.Final(state) != Weight::Zero())) break;
160     enqueued_[state] = false;
161     const auto r = radder_[state].Sum();
162     radder_[state].Reset();
163     for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
164          aiter.Next()) {
165       const auto &arc = aiter.Value();
166       if (!arc_filter_(arc)) continue;
167       while (distance_->size() <= arc.nextstate) {
168         distance_->push_back(Weight::Zero());
169         adder_.push_back(Adder<Weight>());
170         radder_.push_back(Adder<Weight>());
171         enqueued_.push_back(false);
172       }
173       if (retain_) {
174         while (sources_.size() <= arc.nextstate) sources_.push_back(kNoStateId);
175         if (sources_[arc.nextstate] != source_id_) {
176           (*distance_)[arc.nextstate] = Weight::Zero();
177           adder_[arc.nextstate].Reset();
178           radder_[arc.nextstate].Reset();
179           enqueued_[arc.nextstate] = false;
180           sources_[arc.nextstate] = source_id_;
181         }
182       }
183       auto &nd = (*distance_)[arc.nextstate];
184       auto &na = adder_[arc.nextstate];
185       auto &nr = radder_[arc.nextstate];
186       auto weight = Times(r, arc.weight);
187       if (!ApproxEqual(nd, Plus(nd, weight), delta_)) {
188         nd = na.Add(weight);
189         nr.Add(weight);
190         if (!nd.Member() || !nr.Sum().Member()) {
191           error_ = true;
192           return;
193         }
194         if (!enqueued_[arc.nextstate]) {
195           state_queue_->Enqueue(arc.nextstate);
196           enqueued_[arc.nextstate] = true;
197         } else {
198           state_queue_->Update(arc.nextstate);
199         }
200       }
201     }
202   }
203   ++source_id_;
204   if (fst_.Properties(kError, false)) error_ = true;
205 }
206
207 }  // namespace internal
208
209 // Shortest-distance algorithm: this version allows fine control
210 // via the options argument. See below for a simpler interface.
211 //
212 // This computes the shortest distance from the opts.source state to each
213 // visited state S and stores the value in the distance vector. An
214 // nvisited state S has distance Zero(), which will be stored in the
215 // distance vector if S is less than the maximum visited state. The state
216 // queue discipline, arc filter, and convergence delta are taken in the
217 // options argument. The distance vector will contain a unique element for
218 // which Member() is false if an error was encountered.
219 //
220 // The weights must must be right distributive and k-closed (i.e., 1 +
221 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
222 //
223 // Complexity:
224 //
225 // Depends on properties of the semiring and the queue discipline.
226 //
227 // For more information, see:
228 //
229 // Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
230 // problems, Journal of Automata, Languages and
231 // Combinatorics 7(3): 321-350, 2002.
232 template <class Arc, class Queue, class ArcFilter>
233 void ShortestDistance(
234     const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance,
235     const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
236   internal::ShortestDistanceState<Arc, Queue, ArcFilter> sd_state(fst, distance,
237                                                                   opts, false);
238   sd_state.ShortestDistance(opts.source);
239   if (sd_state.Error()) {
240     distance->clear();
241     distance->resize(1, Arc::Weight::NoWeight());
242   }
243 }
244
245 // Shortest-distance algorithm: simplified interface. See above for a version
246 // that permits finer control.
247 //
248 // If reverse is false, this computes the shortest distance from the initial
249 // state to each state S and stores the value in the distance vector. If
250 // reverse is true, this computes the shortest distance from each state to the
251 // final states. An unvisited state S has distance Zero(), which will be stored
252 // in the distance vector if S is less than the maximum visited state. The
253 // state queue discipline is automatically-selected. The distance vector will
254 // contain a unique element for which Member() is false if an error was
255 // encountered.
256 //
257 // The weights must must be right (left) distributive if reverse is false (true)
258 // and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
259 //
260 // Arc weights must satisfy the property that the sum of the weights of one or
261 // more paths from some state S to T is never Zero(). In particular, arc weights
262 // are never Zero().
263 //
264 // Complexity:
265 //
266 // Depends on properties of the semiring and the queue discipline.
267 //
268 // For more information, see:
269 //
270 // Mohri, M. 2002. Semiring framework and algorithms for
271 // shortest-distance problems, Journal of Automata, Languages and
272 // Combinatorics 7(3): 321-350, 2002.
273 template <class Arc>
274 void ShortestDistance(const Fst<Arc> &fst,
275                       std::vector<typename Arc::Weight> *distance,
276                       bool reverse = false, float delta = kShortestDelta) {
277   using StateId = typename Arc::StateId;
278   using Weight = typename Arc::Weight;
279   if (!reverse) {
280     AnyArcFilter<Arc> arc_filter;
281     AutoQueue<StateId> state_queue(fst, distance, arc_filter);
282     const ShortestDistanceOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>>
283         opts(&state_queue, arc_filter, kNoStateId, delta);
284     ShortestDistance(fst, distance, opts);
285   } else {
286     using ReverseArc = ReverseArc<Arc>;
287     using ReverseWeight = typename ReverseArc::Weight;
288     AnyArcFilter<ReverseArc> rarc_filter;
289     VectorFst<ReverseArc> rfst;
290     Reverse(fst, &rfst);
291     std::vector<ReverseWeight> rdistance;
292     AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
293     const ShortestDistanceOptions<ReverseArc, AutoQueue<StateId>,
294                                   AnyArcFilter<ReverseArc>>
295         ropts(&state_queue, rarc_filter, kNoStateId, delta);
296     ShortestDistance(rfst, &rdistance, ropts);
297     distance->clear();
298     if (rdistance.size() == 1 && !rdistance[0].Member()) {
299       distance->resize(1, Arc::Weight::NoWeight());
300       return;
301     }
302     while (distance->size() < rdistance.size() - 1) {
303       distance->push_back(rdistance[distance->size() + 1].Reverse());
304     }
305   }
306 }
307
308 // Return the sum of the weight of all successful paths in an FST, i.e., the
309 // shortest-distance from the initial state to the final states. Returns a
310 // weight such that Member() is false if an error was encountered.
311 template <class Arc>
312 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst,
313                                       float delta = kShortestDelta) {
314   using StateId = typename Arc::StateId;
315   using Weight = typename Arc::Weight;
316   std::vector<Weight> distance;
317   if (Weight::Properties() & kRightSemiring) {
318     ShortestDistance(fst, &distance, false, delta);
319     if (distance.size() == 1 && !distance[0].Member()) {
320       return Arc::Weight::NoWeight();
321     }
322     Adder<Weight> adder;  // maintains cumulative sum accurately
323     for (StateId state = 0; state < distance.size(); ++state) {
324       adder.Add(Times(distance[state], fst.Final(state)));
325     }
326     return adder.Sum();
327   } else {
328     ShortestDistance(fst, &distance, true, delta);
329     const auto state = fst.Start();
330     if (distance.size() == 1 && !distance[0].Member()) {
331       return Arc::Weight::NoWeight();
332     }
333     return state != kNoStateId && state < distance.size() ? distance[state]
334                                                           : Weight::Zero();
335   }
336 }
337
338 }  // namespace fst
339
340 #endif  // FST_SHORTEST_DISTANCE_H_