1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 #ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_
5 #define FST_SCRIPT_SHORTEST_DISTANCE_H_
10 #include <fst/queue.h>
11 #include <fst/shortest-distance.h>
12 #include <fst/script/fst-class.h>
13 #include <fst/script/prune.h>
14 #include <fst/script/script-impl.h>
15 #include <fst/script/weight-class.h>
23 INPUT_EPSILON_ARC_FILTER,
24 OUTPUT_EPSILON_ARC_FILTER
// Run-time options for the scripting-level ShortestDistance: selects the
// queue discipline and the arc filter by enum value (resolved into template
// arguments by the dispatch functions below). It also carries `source` and
// `delta`, which internal::ShortestDistance forwards as the corresponding
// arguments of fst::ShortestDistanceOptions (presumably the source state and
// the convergence delta — confirm in fst/shortest-distance.h).
27 struct ShortestDistanceOptions {
// Queue discipline; switched on in ShortestDistance(ShortestDistanceArgs1 *).
28 const QueueType queue_type;
// Which arcs are followed; switched on in internal::ShortestDistance<Arc, Queue>.
29 const ArcFilterType arc_filter_type;
33 ShortestDistanceOptions(QueueType queue_type, ArcFilterType arc_filter_type,
34 int64 source, float delta)
35 : queue_type(queue_type),
36 arc_filter_type(arc_filter_type),
43 // Code to implement switching on queue and arc filter types.
// Factory used by internal::ShortestDistance to allocate the queue object for
// a given Queue type. The returned pointer is owned by the caller, which wraps
// it in std::unique_ptr. This primary template leaves both of its arguments
// unnamed, i.e. it uses neither the FST nor the distance vector; queue types
// whose constructors need them are handled by the specializations below.
// NOTE(review): presumably the body default-constructs Queue — confirm.
45 template <class Arc, class Queue, class ArcFilter>
46 struct QueueConstructor {
47 using Weight = typename Arc::Weight;
49 static Queue *Construct(const Fst<Arc> &, const std::vector<Weight> *) {
54 // Specializations to support queues with different constructors.
// Specialization for AutoQueue. Its constructor receives the FST, the pointer
// to the distance vector being computed, and an arc-filter instance
// (ArcFilter is default-constructed here).
56 template <class Arc, class ArcFilter>
57 struct QueueConstructor<Arc, AutoQueue<typename Arc::StateId>, ArcFilter> {
58 using StateId = typename Arc::StateId;
59 using Weight = typename Arc::Weight;
61 // Returns a newly allocated AutoQueue; ownership passes to the caller.
62 static AutoQueue<StateId> *Construct(const Fst<Arc> &fst,
63 const std::vector<Weight> *distance) {
64 return new AutoQueue<StateId>(fst, distance, ArcFilter());
// Specialization for NaturalShortestFirstQueue. The queue is constructed from
// the distance vector itself (`*distance`); the FST argument is unused.
// NOTE(review): `distance` is dereferenced unconditionally, so it must be
// non-null. If the queue keeps a reference to the vector (likely, since it
// orders states by their current distance — confirm in fst/queue.h), the
// vector must outlive the queue.
68 template <class Arc, class ArcFilter>
69 struct QueueConstructor<
70 Arc, NaturalShortestFirstQueue<typename Arc::StateId, typename Arc::Weight>,
72 using StateId = typename Arc::StateId;
73 using Weight = typename Arc::Weight;
75 static NaturalShortestFirstQueue<StateId, Weight> *Construct(
76 const Fst<Arc> &, const std::vector<Weight> *distance) {
77 return new NaturalShortestFirstQueue<StateId, Weight>(*distance);
// Specialization for TopOrderQueue. It is built from the FST and a
// default-constructed arc filter; the distance vector is not used.
// NOTE(review): a topological order only exists for acyclic FSTs; check how
// TopOrderQueue reports cyclic input before selecting this queue type.
81 template <class Arc, class ArcFilter>
82 struct QueueConstructor<Arc, TopOrderQueue<typename Arc::StateId>, ArcFilter> {
83 using StateId = typename Arc::StateId;
84 using Weight = typename Arc::Weight;
86 static TopOrderQueue<StateId> *Construct(const Fst<Arc> &fst,
87 const std::vector<Weight> *) {
88 return new TopOrderQueue<StateId>(fst, ArcFilter());
// Runs the library shortest-distance algorithm with a statically chosen Queue
// and ArcFilter.
//   fst      - the arc-typed input FST.
//   distance - output vector for the computed weights. It is also handed to
//              QueueConstructor, so queues that need it can see it.
//   opts     - only opts.source and opts.delta are read here. The queue and
//              arc-filter enums were already resolved into the Queue and
//              ArcFilter template arguments by the callers.
92 template <class Arc, class Queue, class ArcFilter>
93 void ShortestDistance(const Fst<Arc> &fst,
94 std::vector<typename Arc::Weight> *distance,
95 const ShortestDistanceOptions &opts) {
// The queue is owned here. `sopts` only borrows it via queue.get(), so it
// stays alive for the whole computation and is freed on return.
96 std::unique_ptr<Queue> queue(
97 QueueConstructor<Arc, Queue, ArcFilter>::Construct(fst, distance));
98 const fst::ShortestDistanceOptions<Arc, Queue, ArcFilter> sopts(
99 queue.get(), ArcFilter(), opts.source, opts.delta);
100 ShortestDistance(fst, distance, sopts);
// Resolves the run-time opts.arc_filter_type into an ArcFilter template
// argument. It then calls the fully typed internal::ShortestDistance<Arc,
// Queue, ArcFilter> above with AnyArcFilter, EpsilonArcFilter,
// InputEpsilonArcFilter or OutputEpsilonArcFilter; `opts` is forwarded
// unchanged. An unrecognized filter type is reported through FSTERROR() and
// `distance` is resized to a single NoWeight() entry.
103 template <class Arc, class Queue>
104 void ShortestDistance(const Fst<Arc> &fst,
105 std::vector<typename Arc::Weight> *distance,
106 const ShortestDistanceOptions &opts) {
107 switch (opts.arc_filter_type) {
108 case ANY_ARC_FILTER: {
109 ShortestDistance<Arc, Queue, AnyArcFilter<Arc>>(fst, distance, opts);
112 case EPSILON_ARC_FILTER: {
113 ShortestDistance<Arc, Queue, EpsilonArcFilter<Arc>>(fst, distance, opts);
116 case INPUT_EPSILON_ARC_FILTER: {
117 ShortestDistance<Arc, Queue, InputEpsilonArcFilter<Arc>>(fst, distance,
121 case OUTPUT_EPSILON_ARC_FILTER: {
122 ShortestDistance<Arc, Queue, OutputEpsilonArcFilter<Arc>>(fst, distance,
127 FSTERROR() << "ShortestDistance: Unknown arc filter type: "
128 << opts.arc_filter_type;
// NOTE(review): resize(1, w) writes `w` only into newly added slots. The
// result is exactly {NoWeight()} only if `distance` is empty at this point.
// The Args1 caller passes a fresh vector; if a clear() does not already
// precede this line, add one so the error result does not depend on the
// caller.
130 distance->resize(1, Arc::Weight::NoWeight());
136 } // namespace internal
// Argument bundle for the options-based, type-erased entry point
// ShortestDistance(ShortestDistanceArgs1 *):
//   element 0 - input FST (unwrapped with GetFst<Arc>()),
//   element 1 - output WeightClass vector (filled via internal::CopyWeights),
//   element 2 - queue / arc-filter / source / delta options.
138 using ShortestDistanceArgs1 =
139 std::tuple<const FstClass &, std::vector<WeightClass> *,
140 const ShortestDistanceOptions &>;
// Arc-typed worker for ShortestDistanceArgs1. It unwraps the FST and resolves
// opts.queue_type into a queue class: AutoQueue, FifoQueue, LifoQueue,
// NaturalShortestFirstQueue, StateOrderQueue or TopOrderQueue. It then calls
// internal::ShortestDistance on a typed vector and hands the result to
// internal::CopyWeights, which presumably converts it into the caller's
// WeightClass vector (std::get<1>(*args)). An unknown queue type is reported
// through FSTERROR() and produces a single NoWeight() entry.
// NOTE(review): the result of GetFst<Arc>() is dereferenced without a null
// check; presumably the dispatcher has already verified the arc type —
// confirm.
143 void ShortestDistance(ShortestDistanceArgs1 *args) {
144 using StateId = typename Arc::StateId;
145 using Weight = typename Arc::Weight;
146 const Fst<Arc> &fst = *(std::get<0>(*args).GetFst<Arc>());
147 const auto &opts = std::get<2>(*args);
148 std::vector<Weight> typed_distance;
149 switch (opts.queue_type) {
151 internal::ShortestDistance<Arc, AutoQueue<StateId>>(fst, &typed_distance,
156 internal::ShortestDistance<Arc, FifoQueue<StateId>>(fst, &typed_distance,
161 internal::ShortestDistance<Arc, LifoQueue<StateId>>(fst, &typed_distance,
165 case SHORTEST_FIRST_QUEUE: {
166 internal::ShortestDistance<Arc,
167 NaturalShortestFirstQueue<StateId, Weight>>(
168 fst, &typed_distance, opts);
171 case STATE_ORDER_QUEUE: {
172 internal::ShortestDistance<Arc, StateOrderQueue<StateId>>(
173 fst, &typed_distance, opts);
176 case TOP_ORDER_QUEUE: {
177 internal::ShortestDistance<Arc, TopOrderQueue<StateId>>(
178 fst, &typed_distance, opts);
182 FSTERROR() << "ShortestDistance: Unknown queue type: " << opts.queue_type;
// clear() before resize(1, ...) guarantees the error result is exactly
// {NoWeight()}; resize alone would keep any existing element 0.
183 typed_distance.clear();
184 typed_distance.resize(1, Arc::Weight::NoWeight());
188 internal::CopyWeights(typed_distance, std::get<1>(*args));
// Argument bundle for ShortestDistance(ShortestDistanceArgs2 *):
// (input FST, output WeightClass vector, bool, double). The bool is forwarded
// as the third argument of the arc-level ShortestDistance call. The bool and
// double appear to correspond to the `reverse` and `delta` parameters of the
// FstClass overload declared at the end of this header — confirm.
191 using ShortestDistanceArgs2 =
192 std::tuple<const FstClass &, std::vector<WeightClass> *, bool, double>;
// Arc-typed worker for ShortestDistanceArgs2. It unwraps the FST and runs the
// arc-level ShortestDistance into a typed vector using the bool (and
// presumably the double) from the tuple. It then hands the result to
// internal::CopyWeights for the caller's WeightClass vector.
// NOTE(review): the result of GetFst<Arc>() is dereferenced without a null
// check; presumably the dispatcher has already verified the arc type —
// confirm.
195 void ShortestDistance(ShortestDistanceArgs2 *args) {
196 using Weight = typename Arc::Weight;
197 const Fst<Arc> &fst = *(std::get<0>(*args).GetFst<Arc>());
198 std::vector<Weight> typed_distance;
199 ShortestDistance(fst, &typed_distance, std::get<2>(*args),
201 internal::CopyWeights(typed_distance, std::get<1>(*args));
// Type-erased shortest distance with explicit options: queue discipline, arc
// filter, source and delta, all taken from `opts`. Results are written to
// `*distance`. Presumably implemented via the ShortestDistanceArgs1 worker
// above. If so, an unknown queue or arc-filter type is reported through
// FSTERROR() and `*distance` receives a single NoWeight() entry.
204 void ShortestDistance(const FstClass &fst, std::vector<WeightClass> *distance,
205 const ShortestDistanceOptions &opts);
// Type-erased convenience overload. It writes shortest distances for `ifst`
// into `*distance` using default options. `reverse` defaults to false and
// `delta` to fst::kShortestDelta. Presumably `reverse` switches from
// distances out of the initial state to distances into the final states —
// confirm in fst/shortest-distance.h.
207 void ShortestDistance(const FstClass &ifst, std::vector<WeightClass> *distance,
208 bool reverse = false,
209 double delta = fst::kShortestDelta);
211 } // namespace script
214 #endif // FST_SCRIPT_SHORTEST_DISTANCE_H_