Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / script / shortest-distance.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3
4 #ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_
5 #define FST_SCRIPT_SHORTEST_DISTANCE_H_
6
#include <memory>
#include <tuple>
#include <vector>

#include <fst/queue.h>
#include <fst/shortest-distance.h>
#include <fst/script/fst-class.h>
#include <fst/script/prune.h>
#include <fst/script/script-impl.h>
#include <fst/script/weight-class.h>
16
17 namespace fst {
18 namespace script {
19
// Names the arc filter used by the scripting-level ShortestDistance. The
// internal dispatch code in this header maps each value to the filter class
// noted beside it.
enum ArcFilterType {
  ANY_ARC_FILTER = 0,            // AnyArcFilter<Arc>
  EPSILON_ARC_FILTER = 1,        // EpsilonArcFilter<Arc>
  INPUT_EPSILON_ARC_FILTER = 2,  // InputEpsilonArcFilter<Arc>
  OUTPUT_EPSILON_ARC_FILTER = 3  // OutputEpsilonArcFilter<Arc>
};
26
27 struct ShortestDistanceOptions {
28   const QueueType queue_type;
29   const ArcFilterType arc_filter_type;
30   const int64 source;
31   const float delta;
32
33   ShortestDistanceOptions(QueueType queue_type, ArcFilterType arc_filter_type,
34                           int64 source, float delta)
35       : queue_type(queue_type),
36         arc_filter_type(arc_filter_type),
37         source(source),
38         delta(delta) {}
39 };
40
41 namespace internal {
42
43 // Code to implement switching on queue and arc filter types.
44
45 template <class Arc, class Queue, class ArcFilter>
46 struct QueueConstructor {
47   using Weight = typename Arc::Weight;
48
49   static Queue *Construct(const Fst<Arc> &, const std::vector<Weight> *) {
50     return new Queue();
51   }
52 };
53
54 // Specializations to support queues with different constructors.
55
56 template <class Arc, class ArcFilter>
57 struct QueueConstructor<Arc, AutoQueue<typename Arc::StateId>, ArcFilter> {
58   using StateId = typename Arc::StateId;
59   using Weight = typename Arc::Weight;
60
61   //  template<class Arc, class ArcFilter>
62   static AutoQueue<StateId> *Construct(const Fst<Arc> &fst,
63                                        const std::vector<Weight> *distance) {
64     return new AutoQueue<StateId>(fst, distance, ArcFilter());
65   }
66 };
67
68 template <class Arc, class ArcFilter>
69 struct QueueConstructor<
70     Arc, NaturalShortestFirstQueue<typename Arc::StateId, typename Arc::Weight>,
71     ArcFilter> {
72   using StateId = typename Arc::StateId;
73   using Weight = typename Arc::Weight;
74
75   static NaturalShortestFirstQueue<StateId, Weight> *Construct(
76       const Fst<Arc> &, const std::vector<Weight> *distance) {
77     return new NaturalShortestFirstQueue<StateId, Weight>(*distance);
78   }
79 };
80
81 template <class Arc, class ArcFilter>
82 struct QueueConstructor<Arc, TopOrderQueue<typename Arc::StateId>, ArcFilter> {
83   using StateId = typename Arc::StateId;
84   using Weight = typename Arc::Weight;
85
86   static TopOrderQueue<StateId> *Construct(const Fst<Arc> &fst,
87                                            const std::vector<Weight> *) {
88     return new TopOrderQueue<StateId>(fst, ArcFilter());
89   }
90 };
91
92 template <class Arc, class Queue, class ArcFilter>
93 void ShortestDistance(const Fst<Arc> &fst,
94                       std::vector<typename Arc::Weight> *distance,
95                       const ShortestDistanceOptions &opts) {
96   std::unique_ptr<Queue> queue(
97       QueueConstructor<Arc, Queue, ArcFilter>::Construct(fst, distance));
98   const fst::ShortestDistanceOptions<Arc, Queue, ArcFilter> sopts(
99       queue.get(), ArcFilter(), opts.source, opts.delta);
100   ShortestDistance(fst, distance, sopts);
101 }
102
103 template <class Arc, class Queue>
104 void ShortestDistance(const Fst<Arc> &fst,
105                       std::vector<typename Arc::Weight> *distance,
106                       const ShortestDistanceOptions &opts) {
107   switch (opts.arc_filter_type) {
108     case ANY_ARC_FILTER: {
109       ShortestDistance<Arc, Queue, AnyArcFilter<Arc>>(fst, distance, opts);
110       return;
111     }
112     case EPSILON_ARC_FILTER: {
113       ShortestDistance<Arc, Queue, EpsilonArcFilter<Arc>>(fst, distance, opts);
114       return;
115     }
116     case INPUT_EPSILON_ARC_FILTER: {
117       ShortestDistance<Arc, Queue, InputEpsilonArcFilter<Arc>>(fst, distance,
118                                                                opts);
119       return;
120     }
121     case OUTPUT_EPSILON_ARC_FILTER: {
122       ShortestDistance<Arc, Queue, OutputEpsilonArcFilter<Arc>>(fst, distance,
123                                                                 opts);
124       return;
125     }
126     default: {
127       FSTERROR() << "ShortestDistance: Unknown arc filter type: "
128                  << opts.arc_filter_type;
129       distance->clear();
130       distance->resize(1, Arc::Weight::NoWeight());
131       return;
132     }
133   }
134 }
135
136 }  // namespace internal
137
138 using ShortestDistanceArgs1 =
139     std::tuple<const FstClass &, std::vector<WeightClass> *,
140                const ShortestDistanceOptions &>;
141
142 template <class Arc>
143 void ShortestDistance(ShortestDistanceArgs1 *args) {
144   using StateId = typename Arc::StateId;
145   using Weight = typename Arc::Weight;
146   const Fst<Arc> &fst = *(std::get<0>(*args).GetFst<Arc>());
147   const auto &opts = std::get<2>(*args);
148   std::vector<Weight> typed_distance;
149   switch (opts.queue_type) {
150     case AUTO_QUEUE: {
151       internal::ShortestDistance<Arc, AutoQueue<StateId>>(fst, &typed_distance,
152                                                           opts);
153       break;
154     }
155     case FIFO_QUEUE: {
156       internal::ShortestDistance<Arc, FifoQueue<StateId>>(fst, &typed_distance,
157                                                           opts);
158       break;
159     }
160     case LIFO_QUEUE: {
161       internal::ShortestDistance<Arc, LifoQueue<StateId>>(fst, &typed_distance,
162                                                           opts);
163       break;
164     }
165     case SHORTEST_FIRST_QUEUE: {
166       internal::ShortestDistance<Arc,
167                                  NaturalShortestFirstQueue<StateId, Weight>>(
168           fst, &typed_distance, opts);
169       break;
170     }
171     case STATE_ORDER_QUEUE: {
172       internal::ShortestDistance<Arc, StateOrderQueue<StateId>>(
173           fst, &typed_distance, opts);
174       break;
175     }
176     case TOP_ORDER_QUEUE: {
177       internal::ShortestDistance<Arc, TopOrderQueue<StateId>>(
178           fst, &typed_distance, opts);
179       break;
180     }
181     default: {
182       FSTERROR() << "ShortestDistance: Unknown queue type: " << opts.queue_type;
183       typed_distance.clear();
184       typed_distance.resize(1, Arc::Weight::NoWeight());
185       break;
186     }
187   }
188   internal::CopyWeights(typed_distance, std::get<1>(*args));
189 }
190
191 using ShortestDistanceArgs2 =
192     std::tuple<const FstClass &, std::vector<WeightClass> *, bool, double>;
193
194 template <class Arc>
195 void ShortestDistance(ShortestDistanceArgs2 *args) {
196   using Weight = typename Arc::Weight;
197   const Fst<Arc> &fst = *(std::get<0>(*args).GetFst<Arc>());
198   std::vector<Weight> typed_distance;
199   ShortestDistance(fst, &typed_distance, std::get<2>(*args),
200                    std::get<3>(*args));
201   internal::CopyWeights(typed_distance, std::get<1>(*args));
202 }
203
204 void ShortestDistance(const FstClass &fst, std::vector<WeightClass> *distance,
205                       const ShortestDistanceOptions &opts);
206
207 void ShortestDistance(const FstClass &ifst, std::vector<WeightClass> *distance,
208                       bool reverse = false,
209                       double delta = fst::kShortestDelta);
210
211 }  // namespace script
212 }  // namespace fst
213
214 #endif  // FST_SCRIPT_SHORTEST_DISTANCE_H_