3bdb37536094e19c6a934d919a848d37ad867250
[platform/upstream/openfst.git] / src / include / fst / shortest-distance.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes to find shortest distance in an FST.
5
6 #ifndef FST_SHORTEST_DISTANCE_H_
7 #define FST_SHORTEST_DISTANCE_H_
8
9 #include <deque>
10 #include <vector>
11
12 #include <fst/log.h>
13
14 #include <fst/arcfilter.h>
15 #include <fst/cache.h>
16 #include <fst/queue.h>
17 #include <fst/reverse.h>
18 #include <fst/test-properties.h>
19
20
21 namespace fst {
22
23 template <class Arc, class Queue, class ArcFilter>
24 struct ShortestDistanceOptions {
25   using StateId = typename Arc::StateId;
26
27   Queue *state_queue;    // Queue discipline used; owned by caller.
28   ArcFilter arc_filter;  // Arc filter (e.g., limit to only epsilon graph).
29   StateId source;        // If kNoStateId, use the FST's initial state.
30   float delta;           // Determines the degree of convergence required
31   bool first_path;       // For a semiring with the path property (o.w.
32                          // undefined), compute the shortest-distances along
33                          // along the first path to a final state found
34                          // by the algorithm. That path is the shortest-path
35                          // only if the FST has a unique final state (or all
36                          // the final states have the same final weight), the
37                          // queue discipline is shortest-first and all the
38                          // weights in the FST are between One() and Zero()
39                          // according to NaturalLess.
40
41   ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
42                           StateId source = kNoStateId, float delta = kDelta)
43       : state_queue(state_queue),
44         arc_filter(arc_filter),
45         source(source),
46         delta(delta),
47         first_path(false) {}
48 };
49
50 namespace internal {
51
52 // Computation state of the shortest-distance algorithm. Reusable information
53 // is maintained across calls to member function ShortestDistance(source) when
54 // retain is true for improved efficiency when calling multiple times from
55 // different source states (e.g., in epsilon removal). Contrary to the usual
56 // conventions, fst may not be freed before this class. Vector distance
57 // should not be modified by the user between these calls. The Error() method
58 // returns true iff an error was encountered.
59 template <class Arc, class Queue, class ArcFilter>
60 class ShortestDistanceState {
61  public:
62   using StateId = typename Arc::StateId;
63   using Weight = typename Arc::Weight;
64
65   ShortestDistanceState(
66       const Fst<Arc> &fst, std::vector<Weight> *distance,
67       const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain)
68       : fst_(fst),
69         distance_(distance),
70         state_queue_(opts.state_queue),
71         arc_filter_(opts.arc_filter),
72         delta_(opts.delta),
73         first_path_(opts.first_path),
74         retain_(retain),
75         source_id_(0),
76         error_(false) {
77     distance_->clear();
78   }
79
80   void ShortestDistance(StateId source);
81
82   bool Error() const { return error_; }
83
84  private:
85   const Fst<Arc> &fst_;
86   std::vector<Weight> *distance_;
87   Queue *state_queue_;
88   ArcFilter arc_filter_;
89   const float delta_;
90   const bool first_path_;
91   const bool retain_;  // Retain and reuse information across calls.
92
93   std::vector<Adder<Weight>> adder_;   // Sums distance_ accurately.
94   std::vector<Adder<Weight>> radder_;  // Relaxation distance.
95   std::vector<bool> enqueued_;     // Is state enqueued?
96   std::vector<StateId> sources_;   // Source ID for ith state in distance_,
97                                    // (r)adder_, and enqueued_ if retained.
98   StateId source_id_;              // Unique ID characterizing each call.
99   bool error_;
100 };
101
102 // Compute the shortest distance; if source is kNoStateId, uses the initial
103 // state of the FST.
104 template <class Arc, class Queue, class ArcFilter>
105 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
106     StateId source) {
107   if (fst_.Start() == kNoStateId) {
108     if (fst_.Properties(kError, false)) error_ = true;
109     return;
110   }
111   if (!(Weight::Properties() & kRightSemiring)) {
112     FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
113                << Weight::Type();
114     error_ = true;
115     return;
116   }
117   if (first_path_ && !(Weight::Properties() & kPath)) {
118     FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
119                << "Weight does not have the path property: " << Weight::Type();
120     error_ = true;
121     return;
122   }
123   state_queue_->Clear();
124   if (!retain_) {
125     distance_->clear();
126     adder_.clear();
127     radder_.clear();
128     enqueued_.clear();
129   }
130   if (source == kNoStateId) source = fst_.Start();
131   while (distance_->size() <= source) {
132     distance_->push_back(Weight::Zero());
133     adder_.push_back(Adder<Weight>());
134     radder_.push_back(Adder<Weight>());
135     enqueued_.push_back(false);
136   }
137   if (retain_) {
138     while (sources_.size() <= source) sources_.push_back(kNoStateId);
139     sources_[source] = source_id_;
140   }
141   (*distance_)[source] = Weight::One();
142   adder_[source].Reset(Weight::One());
143   radder_[source].Reset(Weight::One());
144   enqueued_[source] = true;
145   state_queue_->Enqueue(source);
146   while (!state_queue_->Empty()) {
147     const auto state = state_queue_->Head();
148     state_queue_->Dequeue();
149     while (distance_->size() <= state) {
150       distance_->push_back(Weight::Zero());
151       adder_.push_back(Adder<Weight>());
152       radder_.push_back(Adder<Weight>());
153       enqueued_.push_back(false);
154     }
155     if (first_path_ && (fst_.Final(state) != Weight::Zero())) break;
156     enqueued_[state] = false;
157     const auto r = radder_[state].Sum();
158     radder_[state].Reset();
159     for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
160          aiter.Next()) {
161       const auto &arc = aiter.Value();
162       if (!arc_filter_(arc)) continue;
163       while (distance_->size() <= arc.nextstate) {
164         distance_->push_back(Weight::Zero());
165         adder_.push_back(Adder<Weight>());
166         radder_.push_back(Adder<Weight>());
167         enqueued_.push_back(false);
168       }
169       if (retain_) {
170         while (sources_.size() <= arc.nextstate) sources_.push_back(kNoStateId);
171         if (sources_[arc.nextstate] != source_id_) {
172           (*distance_)[arc.nextstate] = Weight::Zero();
173           adder_[arc.nextstate].Reset();
174           radder_[arc.nextstate].Reset();
175           enqueued_[arc.nextstate] = false;
176           sources_[arc.nextstate] = source_id_;
177         }
178       }
179       auto &nd = (*distance_)[arc.nextstate];
180       auto &na = adder_[arc.nextstate];
181       auto &nr = radder_[arc.nextstate];
182       auto weight = Times(r, arc.weight);
183       if (!ApproxEqual(nd, Plus(nd, weight), delta_)) {
184         nd = na.Add(weight);
185         nr.Add(weight);
186         if (!nd.Member() || !nr.Sum().Member()) {
187           error_ = true;
188           return;
189         }
190         if (!enqueued_[arc.nextstate]) {
191           state_queue_->Enqueue(arc.nextstate);
192           enqueued_[arc.nextstate] = true;
193         } else {
194           state_queue_->Update(arc.nextstate);
195         }
196       }
197     }
198   }
199   ++source_id_;
200   if (fst_.Properties(kError, false)) error_ = true;
201 }
202
203 }  // namespace internal
204
205 // Shortest-distance algorithm: this version allows fine control
206 // via the options argument. See below for a simpler interface.
207 //
208 // This computes the shortest distance from the opts.source state to each
209 // visited state S and stores the value in the distance vector. An
210 // nvisited state S has distance Zero(), which will be stored in the
211 // distance vector if S is less than the maximum visited state. The state
212 // queue discipline, arc filter, and convergence delta are taken in the
213 // options argument. The distance vector will contain a unique element for
214 // which Member() is false if an error was encountered.
215 //
216 // The weights must must be right distributive and k-closed (i.e., 1 +
217 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
218 //
219 // Complexity:
220 //
221 // Depends on properties of the semiring and the queue discipline.
222 //
223 // For more information, see:
224 //
225 // Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
226 // problems, Journal of Automata, Languages and
227 // Combinatorics 7(3): 321-350, 2002.
228 template <class Arc, class Queue, class ArcFilter>
229 void ShortestDistance(
230     const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance,
231     const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
232   internal::ShortestDistanceState<Arc, Queue, ArcFilter> sd_state(fst, distance,
233                                                                   opts, false);
234   sd_state.ShortestDistance(opts.source);
235   if (sd_state.Error()) {
236     distance->clear();
237     distance->resize(1, Arc::Weight::NoWeight());
238   }
239 }
240
241 // Shortest-distance algorithm: simplified interface. See above for a version
242 // that permits finer control.
243 //
244 // If reverse is false, this computes the shortest distance from the initial
245 // state to each state S and stores the value in the distance vector. If
246 // reverse is true, this computes the shortest distance from each state to the
247 // final states. An unvisited state S has distance Zero(), which will be stored
248 // in the distance vector if S is less than the maximum visited state. The
249 // state queue discipline is automatically-selected. The distance vector will
250 // contain a unique element for which Member() is false if an error was
251 // encountered.
252 //
253 // The weights must must be right (left) distributive if reverse is false (true)
254 // and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
255 //
256 // Arc weights must satisfy the property that the sum of the weights of one or
257 // more paths from some state S to T is never Zero(). In particular, arc weights
258 // are never Zero().
259 //
260 // Complexity:
261 //
262 // Depends on properties of the semiring and the queue discipline.
263 //
264 // For more information, see:
265 //
266 // Mohri, M. 2002. Semiring framework and algorithms for
267 // shortest-distance problems, Journal of Automata, Languages and
268 // Combinatorics 7(3): 321-350, 2002.
269 template <class Arc>
270 void ShortestDistance(const Fst<Arc> &fst,
271                       std::vector<typename Arc::Weight> *distance,
272                       bool reverse = false, float delta = kDelta) {
273   using StateId = typename Arc::StateId;
274   using Weight = typename Arc::Weight;
275   if (!reverse) {
276     AnyArcFilter<Arc> arc_filter;
277     AutoQueue<StateId> state_queue(fst, distance, arc_filter);
278     const ShortestDistanceOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>>
279         opts(&state_queue, arc_filter, kNoStateId, delta);
280     ShortestDistance(fst, distance, opts);
281   } else {
282     using ReverseArc = ReverseArc<Arc>;
283     using ReverseWeight = typename ReverseArc::Weight;
284     AnyArcFilter<ReverseArc> rarc_filter;
285     VectorFst<ReverseArc> rfst;
286     Reverse(fst, &rfst);
287     std::vector<ReverseWeight> rdistance;
288     AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
289     const ShortestDistanceOptions<ReverseArc, AutoQueue<StateId>,
290                                   AnyArcFilter<ReverseArc>>
291         ropts(&state_queue, rarc_filter, kNoStateId, delta);
292     ShortestDistance(rfst, &rdistance, ropts);
293     distance->clear();
294     if (rdistance.size() == 1 && !rdistance[0].Member()) {
295       distance->resize(1, Arc::Weight::NoWeight());
296       return;
297     }
298     while (distance->size() < rdistance.size() - 1) {
299       distance->push_back(rdistance[distance->size() + 1].Reverse());
300     }
301   }
302 }
303
304 // Return the sum of the weight of all successful paths in an FST, i.e., the
305 // shortest-distance from the initial state to the final states. Returns a
306 // weight such that Member() is false if an error was encountered.
307 template <class Arc>
308 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst,
309                                       float delta = kDelta) {
310   using StateId = typename Arc::StateId;
311   using Weight = typename Arc::Weight;
312   std::vector<Weight> distance;
313   if (Weight::Properties() & kRightSemiring) {
314     ShortestDistance(fst, &distance, false, delta);
315     if (distance.size() == 1 && !distance[0].Member()) {
316       return Arc::Weight::NoWeight();
317     }
318     Adder<Weight> adder;  // maintains cumulative sum accurately
319     for (StateId state = 0; state < distance.size(); ++state) {
320       adder.Add(Times(distance[state], fst.Final(state)));
321     }
322     return adder.Sum();
323   } else {
324     ShortestDistance(fst, &distance, true, delta);
325     const auto state = fst.Start();
326     if (distance.size() == 1 && !distance[0].Member()) {
327       return Arc::Weight::NoWeight();
328     }
329     return state != kNoStateId && state < distance.size() ? distance[state]
330                                                           : Weight::Zero();
331   }
332 }
333
334 }  // namespace fst
335
336 #endif  // FST_SHORTEST_DISTANCE_H_