1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions and classes to find shortest distance in an FST.
6 #ifndef FST_SHORTEST_DISTANCE_H_
7 #define FST_SHORTEST_DISTANCE_H_
14 #include <fst/arcfilter.h>
15 #include <fst/cache.h>
16 #include <fst/queue.h>
17 #include <fst/reverse.h>
18 #include <fst/test-properties.h>
// Options for the shortest-distance algorithm: the queue that orders state
// visits, the filter deciding which arcs are followed, the source state, the
// convergence threshold, and the first_path early-stop flag.
23 template <class Arc, class Queue, class ArcFilter>
24 struct ShortestDistanceOptions {
25 using StateId = typename Arc::StateId;
27 Queue *state_queue; // Queue discipline used; owned by caller.
28 ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph).
29 StateId source; // If kNoStateId, use the FST's initial state.
30 float delta; // Determines the degree of convergence required.
31 bool first_path; // For a semiring with the path property (o.w.
32 // undefined), compute the shortest-distances along
33 // the first path to a final state found
34 // by the algorithm. That path is the shortest-path
35 // only if the FST has a unique final state (or all
36 // the final states have the same final weight), the
37 // queue discipline is shortest-first and all the
38 // weights in the FST are between One() and Zero()
39 // according to NaturalLess.
// first_path is not a constructor argument; since it is a public member,
// callers that need it assign it after construction.
// NOTE(review): only the state_queue and arc_filter initializers are visible
// in this copy; confirm that source, delta and first_path are initialized and
// that the constructor body and the struct are properly closed.
41 ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
42 StateId source = kNoStateId, float delta = kDelta)
43 : state_queue(state_queue),
44 arc_filter(arc_filter),
52 // Computation state of the shortest-distance algorithm. When retain is true,
53 // reusable information is maintained across calls to member function
54 // ShortestDistance(source), which improves efficiency when it is called
55 // multiple times from different source states (e.g., in epsilon removal).
56 // Contrary to the usual conventions, fst must not be freed before this
57 // class. The vector distance should not be modified by the user between these
58 // calls. The Error() method returns true iff an error was encountered.
// Runs the generic single-source shortest-distance algorithm over `fst`,
// writing per-state distances into the caller-supplied `*distance` vector.
// NOTE(review): members used by ShortestDistance() (fst_, state_queue_,
// delta_, error_), the access specifiers, the remaining constructor
// initializers and body, and the closing `};` are not visible in this copy;
// confirm against the upstream header.
59 template <class Arc, class Queue, class ArcFilter>
60 class ShortestDistanceState {
62 using StateId = typename Arc::StateId;
63 using Weight = typename Arc::Weight;
// Copies the queue pointer, arc filter and first_path flag out of `opts`;
// `retain` decides whether per-state bookkeeping is reused across calls.
65 ShortestDistanceState(
66 const Fst<Arc> &fst, std::vector<Weight> *distance,
67 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain)
70 state_queue_(opts.state_queue),
71 arc_filter_(opts.arc_filter),
73 first_path_(opts.first_path),
// Computes distances from `source`; kNoStateId selects the FST's start state.
80 void ShortestDistance(StateId source);
// Whether an error has been recorded by ShortestDistance().
82 bool Error() const { return error_; }
// Output vector supplied by the caller, indexed by state ID.
86 std::vector<Weight> *distance_;
88 ArcFilter arc_filter_;
90 const bool first_path_;
91 const bool retain_; // Retain and reuse information across calls.
93 std::vector<Adder<Weight>> adder_; // Sums distance_ accurately.
94 std::vector<Adder<Weight>> radder_; // Relaxation distance (not yet propagated).
95 std::vector<bool> enqueued_; // Is state enqueued?
96 std::vector<StateId> sources_; // Source ID for ith state in distance_,
97 // (r)adder_, and enqueued_ if retained.
98 StateId source_id_; // Unique ID characterizing each call.
102 // Compute the shortest distance; if source is kNoStateId, uses the initial
// state of the FST.
// Relaxes states in queue order, starting from `source`. For each dequeued
// state, the weight r accumulated in radder_ since it was last dequeued is
// consumed and extended along every arc accepted by arc_filter_. A successor
// whose distance would change by more than delta_ (per ApproxEqual) is
// enqueued, or its queue position is updated.
// NOTE(review): this copy of the function looks truncated. The following are
// not visible here:
//   - the `StateId source) {` that closes the parameter list;
//   - the error_/return handling after each precondition check;
//   - the statements that add `weight` into nd/na/nr;
//   - the aiter.Next() loop increment;
//   - the branch structure around Update();
//   - any increment of source_id_;
//   - most closing braces.
// Confirm against the upstream header before relying on this text.
104 template <class Arc, class Queue, class ArcFilter>
105 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
// No start state: nothing to relax, but propagate an error flagged on the FST.
107 if (fst_.Start() == kNoStateId) {
108 if (fst_.Properties(kError, false)) error_ = true;
// Preconditions: the semiring must be right distributive, and first_path_
// additionally requires the path property.
111 if (!(Weight::Properties() & kRightSemiring)) {
112 FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
117 if (first_path_ && !(Weight::Properties() & kPath)) {
118 FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
119 << "Weight does not have the path property: " << Weight::Type();
// Every call starts from an empty queue.
123 state_queue_->Clear();
// Default to the start state, then grow the per-state vectors so that
// `source` is a valid index.
130 if (source == kNoStateId) source = fst_.Start();
131 while (distance_->size() <= source) {
132 distance_->push_back(Weight::Zero());
133 adder_.push_back(Adder<Weight>());
134 radder_.push_back(Adder<Weight>());
135 enqueued_.push_back(false);
// Tag `source` with this call's ID.
// NOTE(review): the member comment says sources_ is used only "if retained",
// but no retain_ guard is visible around this bookkeeping; confirm.
138 while (sources_.size() <= source) sources_.push_back(kNoStateId);
139 sources_[source] = source_id_;
// Seed the search: the source has distance One() and is the first queued state.
141 (*distance_)[source] = Weight::One();
142 adder_[source].Reset(Weight::One());
143 radder_[source].Reset(Weight::One());
144 enqueued_[source] = true;
145 state_queue_->Enqueue(source);
146 while (!state_queue_->Empty()) {
147 const auto state = state_queue_->Head();
148 state_queue_->Dequeue();
// Grow the per-state vectors on demand so that `state` is a valid index.
149 while (distance_->size() <= state) {
150 distance_->push_back(Weight::Zero());
151 adder_.push_back(Adder<Weight>());
152 radder_.push_back(Adder<Weight>());
153 enqueued_.push_back(false);
// In first_path mode, stop as soon as a final state is dequeued.
155 if (first_path_ && (fst_.Final(state) != Weight::Zero())) break;
156 enqueued_[state] = false;
// Take and clear the weight pending at `state`; it is what gets propagated.
157 const auto r = radder_[state].Sum();
158 radder_[state].Reset();
159 for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
161 const auto &arc = aiter.Value();
162 if (!arc_filter_(arc)) continue;
// Grow the per-state vectors on demand so that arc.nextstate is a valid index.
163 while (distance_->size() <= arc.nextstate) {
164 distance_->push_back(Weight::Zero());
165 adder_.push_back(Adder<Weight>());
166 radder_.push_back(Adder<Weight>());
167 enqueued_.push_back(false);
// A successor last touched under a different source_id_ holds data from an
// earlier call; reset it before use.
170 while (sources_.size() <= arc.nextstate) sources_.push_back(kNoStateId);
171 if (sources_[arc.nextstate] != source_id_) {
172 (*distance_)[arc.nextstate] = Weight::Zero();
173 adder_[arc.nextstate].Reset();
174 radder_[arc.nextstate].Reset();
175 enqueued_[arc.nextstate] = false;
176 sources_[arc.nextstate] = source_id_;
179 auto &nd = (*distance_)[arc.nextstate];
180 auto &na = adder_[arc.nextstate];
181 auto &nr = radder_[arc.nextstate];
182 auto weight = Times(r, arc.weight);
// Only act if `weight` changes the successor's distance beyond delta_.
183 if (!ApproxEqual(nd, Plus(nd, weight), delta_)) {
// Check that the successor's distance and pending weight are valid members.
186 if (!nd.Member() || !nr.Sum().Member()) {
// Queue the successor, or update its position if it is already queued.
190 if (!enqueued_[arc.nextstate]) {
191 state_queue_->Enqueue(arc.nextstate);
192 enqueued_[arc.nextstate] = true;
194 state_queue_->Update(arc.nextstate);
// Propagate any error flagged on the input FST itself.
200 if (fst_.Properties(kError, false)) error_ = true;
203 } // namespace internal
205 // Shortest-distance algorithm: this version allows fine control
206 // via the options argument. See below for a simpler interface.
208 // This computes the shortest distance from the opts.source state to each
209 // visited state S and stores the value in the distance vector. An
210 // unvisited state S has distance Zero(), which will be stored in the
211 // distance vector if S is less than the maximum visited state. The state
212 // queue discipline, arc filter, and convergence delta are taken in the
213 // options argument. The distance vector will contain a unique element for
214 // which Member() is false if an error was encountered.
216 // The weights must be right distributive and k-closed (i.e.,
217 // 1 + x + x^2 + ... + x^(k+1) = 1 + x + x^2 + ... + x^k).
221 // Complexity: depends on properties of the semiring and the queue discipline.
223 // For more information, see:
225 // Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
226 // problems, Journal of Automata, Languages and
227 // Combinatorics 7(3): 321-350, 2002.
228 template <class Arc, class Queue, class ArcFilter>
229 void ShortestDistance(
230 const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance,
231 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
// Run the algorithm once, starting from opts.source.
// NOTE(review): the remaining constructor arguments (presumably opts and the
// retain flag) and the function's closing braces are not visible in this copy.
232 internal::ShortestDistanceState<Arc, Queue, ArcFilter> sd_state(fst, distance,
234 sd_state.ShortestDistance(opts.source);
// On error, report failure as a single non-member element.
235 if (sd_state.Error()) {
// NOTE(review): resize(1, ...) keeps an existing distance[0] and only fills
// newly created slots with NoWeight(). Unless the vector is cleared first (not
// visible here), element 0 may remain a valid weight. Confirm that a clear()
// precedes this call.
237 distance->resize(1, Arc::Weight::NoWeight());
241 // Shortest-distance algorithm: simplified interface. See above for a version
242 // that permits finer control.
244 // If reverse is false, this computes the shortest distance from the initial
245 // state to each state S and stores the value in the distance vector. If
246 // reverse is true, this computes the shortest distance from each state to the
247 // final states. An unvisited state S has distance Zero(), which will be stored
248 // in the distance vector if S is less than the maximum visited state. The
249 // state queue discipline is selected automatically. The distance vector will
250 // contain a unique element for which Member() is false if an error was
// encountered.
253 // The weights must be right (left) distributive if reverse is false (true)
254 // and k-closed (i.e., 1 + x + x^2 + ... + x^(k+1) = 1 + x + x^2 + ... + x^k).
256 // Arc weights must satisfy the property that the sum of the weights of one or
257 // more paths from some state S to T is never Zero(). In particular, arc weights
// must never be Zero().
262 // Complexity: depends on properties of the semiring and the queue discipline.
264 // For more information, see:
266 // Mohri, M. 2002. Semiring framework and algorithms for
267 // shortest-distance problems, Journal of Automata, Languages and
268 // Combinatorics 7(3): 321-350, 2002.
// NOTE(review): this copy of the function looks truncated:
//   - no `template <class Arc>` header precedes it;
//   - the forward and reverse computations both appear unconditional
//     (`reverse` is not otherwise read);
//   - `state_queue` is declared twice;
//   - `rfst` is never populated before use.
// An `if (!reverse) { ... } else { ... }` split and a call that fills rfst
// from `fst` appear to be missing; confirm against the upstream header.
270 void ShortestDistance(const Fst<Arc> &fst,
271 std::vector<typename Arc::Weight> *distance,
272 bool reverse = false, float delta = kDelta) {
273 using StateId = typename Arc::StateId;
274 using Weight = typename Arc::Weight;
// Forward pass from the start state, using AnyArcFilter and an AutoQueue.
276 AnyArcFilter<Arc> arc_filter;
277 AutoQueue<StateId> state_queue(fst, distance, arc_filter);
278 const ShortestDistanceOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>>
279 opts(&state_queue, arc_filter, kNoStateId, delta);
280 ShortestDistance(fst, distance, opts);
// Reverse pass over rfst, a VectorFst of ReverseArc (presumably the reversal
// of `fst`), with results collected in rdistance.
282 using ReverseArc = ReverseArc<Arc>;
283 using ReverseWeight = typename ReverseArc::Weight;
284 AnyArcFilter<ReverseArc> rarc_filter;
285 VectorFst<ReverseArc> rfst;
287 std::vector<ReverseWeight> rdistance;
288 AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
289 const ShortestDistanceOptions<ReverseArc, AutoQueue<StateId>,
290 AnyArcFilter<ReverseArc>>
291 ropts(&state_queue, rarc_filter, kNoStateId, delta);
292 ShortestDistance(rfst, &rdistance, ropts);
// Propagate an error from the reverse computation.
// NOTE(review): `distance` is not visibly cleared first, so the resize below
// keeps any existing element 0.
294 if (rdistance.size() == 1 && !rdistance[0].Member()) {
295 distance->resize(1, Arc::Weight::NoWeight());
// Map results back to the original state IDs. rdistance is read from index 1
// onward (presumably index 0 is an extra initial state introduced by the
// reversal; confirm), and each weight is converted with Reverse().
// NOTE(review): if rdistance could ever be empty, `rdistance.size() - 1`
// would wrap around.
298 while (distance->size() < rdistance.size() - 1) {
299 distance->push_back(rdistance[distance->size() + 1].Reverse());
304 // Return the sum of the weights of all successful paths in an FST, i.e., the
305 // shortest-distance from the initial state to the final states. Returns a
306 // weight such that Member() is false if an error was encountered.
// NOTE(review): in this copy the following are not visible:
//   - the `template <class Arc>` header;
//   - the forward branch's return of the accumulated sum;
//   - the `else` separating the two strategies;
//   - the tail of the final conditional expression.
// Confirm against the upstream header.
308 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst,
309 float delta = kDelta) {
310 using StateId = typename Arc::StateId;
311 using Weight = typename Arc::Weight;
312 std::vector<Weight> distance;
// Right-distributive semirings: compute forward distances from the start
// state and combine them with each state's final weight.
313 if (Weight::Properties() & kRightSemiring) {
314 ShortestDistance(fst, &distance, false, delta);
// Propagate an error from the distance computation.
315 if (distance.size() == 1 && !distance[0].Member()) {
316 return Arc::Weight::NoWeight();
// Sum over states of distance[s] * Final(s).
318 Adder<Weight> adder; // maintains cumulative sum accurately
319 for (StateId state = 0; state < distance.size(); ++state) {
320 adder.Add(Times(distance[state], fst.Final(state)));
// Reverse strategy (presumably the branch for semirings that are not right
// distributive): compute distances to the final states, then read the entry
// for the start state when it has one.
324 ShortestDistance(fst, &distance, true, delta);
325 const auto state = fst.Start();
326 if (distance.size() == 1 && !distance[0].Member()) {
327 return Arc::Weight::NoWeight();
329 return state != kNoStateId && state < distance.size() ? distance[state]
336 #endif // FST_SHORTEST_DISTANCE_H_