1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions implementing pruning.
6 #ifndef FST_LIB_PRUNE_H_
7 #define FST_LIB_PRUNE_H_
14 #include <fst/arcfilter.h>
16 #include <fst/shortest-distance.h>
// Heap comparator used by the Prune() overloads below. A state s is ranked by
// Times(IDistance(s), FDistance(s)), the weight of the best known successful
// path through s; the visible members suggest the ranking uses the semiring's
// natural order (less_), but the comparison statement is not visible here.
// NOTE(review): the `class PruneCompare {` header, access specifiers and
// closing braces are not visible in this view of the file.
22 template <class StateId, class Weight>
// idistance: best distance found so far from the start state to each state.
// fdistance: shortest distance from each state to the final states.
// Only references are stored, so both vectors must outlive the comparator.
25 PruneCompare(const std::vector<Weight> &idistance,
26 const std::vector<Weight> &fdistance)
27 : idistance_(idistance), fdistance_(fdistance) {}
// Computes the total path weight through x and through y.
// NOTE(review): the statement comparing wx and wy (presumably
// `return less_(wx, wy);`) is not visible in this view -- confirm.
29 bool operator()(const StateId x, const StateId y) const {
30 const auto wx = Times(IDistance(x), FDistance(x));
31 const auto wy = Times(IDistance(y), FDistance(y));
// Distance from the start state to s. States beyond the end of the vector
// (it may be grown lazily by the caller) count as unreachable: Weight::Zero().
// The signed StateId is compared against an unsigned size, so a negative id
// also falls through to Weight::Zero().
36 Weight IDistance(const StateId s) const {
37 return s < idistance_.size() ? idistance_[s] : Weight::Zero();
// Distance from s to the final states; out-of-range states yield
// Weight::Zero().
40 Weight FDistance(const StateId s) const {
41 return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
// Borrowed distance vectors, owned by the caller.
44 const std::vector<Weight> &idistance_;
45 const std::vector<Weight> &fdistance_;
// Natural order of the semiring, used to rank path weights.
46 NaturalLess<Weight> less_;
49 } // namespace internal
// Options controlling Prune().
// NOTE(review): the `struct PruneOptions {` line, the declarations of the
// `filter` and `delta` members, and the initializers for `distance` and
// `delta` are not visible in this view of the file.
51 template <class Arc, class ArcFilter>
53 using StateId = typename Arc::StateId;
54 using Weight = typename Arc::Weight;
// `distance` may be nullptr, in which case Prune() computes the distances to
// the final states itself. `threshold_initial` defaults to false
// (threshold multiplied on the right).
56 PruneOptions(const Weight &weight_threshold, StateId state_threshold,
57 ArcFilter filter, std::vector<Weight> *distance = nullptr,
58 float delta = kDelta, bool threshold_initial = false)
// NOTE(review): std::move on a `const Weight &` still copies; the move is a
// no-op here.
59 : weight_threshold(std::move(weight_threshold)),
60 state_threshold(state_threshold),
61 filter(std::move(filter)),
64 threshold_initial(threshold_initial) {}
66 // Pruning weight threshold.
// Prune() keeps paths whose weight is not above Times(d, weight_threshold)
// (or Times(weight_threshold, d), see threshold_initial), where d is the
// shortest-path weight from the start state.
67 Weight weight_threshold;
68 // Pruning state threshold.
// kNoStateId disables the state limit; Prune() also special-cases 0.
69 StateId state_threshold;
72 // If non-null, passes in pre-computed shortest distance to final states.
73 const std::vector<Weight> *distance;
74 // Determines the degree of convergence required when computing shortest
// distances (forwarded to ShortestDistance() by Prune()).
77 // Determines if the shortest path weight is left (true) or right
78 // (false) multiplied by the threshold to get the limit for
79 // keeping a state or arc (matters if the semiring is not
// commutative).
81 bool threshold_initial;
84 // Pruning algorithm: this version modifies its input and it takes an options
85 // class as an argument. After pruning the FST contains states and arcs that
86 // belong to a successful path in the FST whose weight is no more than the
87 // weight of the shortest path Times() the provided weight threshold. When the
88 // state threshold is not kNoStateId, the output FST is further restricted to
89 // have no more than the number of states in opts.state_threshold. Weights must
90 // have the path property. The weight of any cycle needs to be bounded; i.e.,
92 // Plus(weight, Weight::One()) == Weight::One()
// In-place pruning driven by a PruneOptions object. Outline of the visible
// code:
//   1. reject weights without the kPath property (FST marked with kError);
//   2. obtain the distance from every state to the final states;
//   3. explore states best-first from the start state using PruneCompare,
//      redirecting arcs whose best path exceeds the limit to a sink state;
//   4. delete the sink state and every state that was never expanded.
// NOTE(review): several lines of this function (early returns, heap pops,
// closing braces) are not visible in this view of the file.
93 template <class Arc, class ArcFilter>
94 void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &opts) {
95 using StateId = typename Arc::StateId;
96 using Weight = typename Arc::Weight;
// Priority queue of states ordered by PruneCompare.
97 using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
98 // TODO(kbg): Make this a compile-time static_assert once:
99 // 1) All weight properties are made constexpr for all weight types.
100 // 2) We have a pleasant way to "deregister" this operation for non-path
101 // semirings so an informative error message is produced. The best
102 // solution will probably involve some kind of SFINAE magic.
103 if ((Weight::Properties() & kPath) != kPath) {
104 FSTERROR() << "Prune: Weight needs to have the path property: "
106 fst->SetProperties(kError, kError);
109 auto ns = fst->NumStates();
// idistance[s]: best weight found so far from the start state to s;
// Weight::Zero() means s has not been reached yet.
111 std::vector<Weight> idistance(ns, Weight::Zero());
// Backing storage for the final-state distances when the caller did not
// supply opts.distance.
112 std::vector<Weight> tmp;
113 if (!opts.distance) {
// The `true` argument presumably selects reverse shortest distance, i.e. the
// distance to the final states documented on PruneOptions::distance.
115 ShortestDistance(*fst, &tmp, true, opts.delta);
// fdistance[s]: shortest distance from s to the final states.
117 const auto *fdistance = opts.distance ? opts.distance : &tmp;
// Nothing can be kept in any of three cases: the state budget is 0, no
// distance is recorded for the start state, or no final state is reachable
// from it. A kNoStateId start (-1) compares as a huge unsigned value against
// size(), so it also takes this branch. The branch body is not visible here.
118 if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) ||
119 ((*fdistance)[fst->Start()] == Weight::Zero())) {
123 internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
124 StateHeap heap(compare);
// visited[s]: s has been expanded. enqueued[s]: heap key of s, or kNoKey
// when s is not currently in the heap.
125 std::vector<bool> visited(ns, false);
126 std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
// States to delete at the end; dead[0] is a fresh sink state that receives
// every pruned arc.
127 std::vector<StateId> dead;
128 dead.push_back(fst->AddState());
129 NaturalLess<Weight> less;
130 auto s = fst->Start();
// Largest admissible path weight: the shortest-path weight multiplied by the
// threshold on the side chosen by opts.threshold_initial.
131 const auto limit = opts.threshold_initial ?
132 Times(opts.weight_threshold, (*fdistance)[s]) :
133 Times((*fdistance)[s], opts.weight_threshold);
// Counts states admitted to the search; checked against opts.state_threshold.
// NOTE(review): its increments are not visible in this view.
134 StateId num_visited = 0;
// Seed the search with the start state only if its best path is within the
// limit.
136 if (!less(limit, (*fdistance)[s])) {
137 idistance[s] = Weight::One();
138 enqueued[s] = heap.Insert(s);
// NOTE(review): the lines that pop `s` from the heap and set visited[s] are
// not visible in this view; visited is only read below.
141 while (!heap.Empty()) {
144 enqueued[s] = StateHeap::kNoKey;
// Remove the final weight if the best path ending at s exceeds the limit.
146 if (less(limit, Times(idistance[s], fst->Final(s)))) {
147 fst->SetFinal(s, Weight::Zero());
149 for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
151 auto arc = aiter.Value(); // Copy intended.
// Arcs rejected by the filter are left untouched and not followed.
152 if (!opts.filter(arc)) continue;
// Weight of the best path using this arc: start -> s, the arc itself, then
// nextstate -> final (Zero() if nextstate has no recorded distance).
153 const auto weight = Times(Times(idistance[s], arc.weight),
154 arc.nextstate < fdistance->size() ?
155 (*fdistance)[arc.nextstate] : Weight::Zero());
// Over-limit arcs are redirected to the sink state (the write-back of the
// modified arc is not visible in this view).
156 if (less(limit, weight)) {
157 arc.nextstate = dead[0];
// Relax the best known distance to nextstate.
161 if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
162 idistance[arc.nextstate] = Times(idistance[s], arc.weight);
// Expanded states are not re-queued. New states are not admitted once the
// state budget is used up; the body of that check is not visible here.
164 if (visited[arc.nextstate]) continue;
165 if ((opts.state_threshold != kNoStateId) &&
166 (num_visited >= opts.state_threshold)) {
// Insert a newly discovered state. A state already in the heap is re-sifted
// instead: PruneCompare reads idistance by reference, so its key may have
// just improved.
169 if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
170 enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
173 heap.Update(enqueued[arc.nextstate], arc.nextstate);
// Delete the sink state together with every original state never expanded.
// visited has ns entries, so the sink (added after ns was read) is deleted
// only through dead[0].
177 for (StateId i = 0; i < visited.size(); ++i) {
178 if (!visited[i]) dead.push_back(i);
180 fst->DeleteStates(dead);
183 // Pruning algorithm: this version modifies its input and takes the
184 // pruning threshold as an argument. It deletes states and arcs in the
185 // FST that do not belong to a successful path whose weight is no more
186 // than the weight of the shortest path Times() the provided weight
187 // threshold. When the state threshold is not kNoStateId, the output
188 // FST is further restricted to have no more than the number of states
189 // given by state_threshold. Weights must have the path property. The
190 // weight of any cycle needs to be bounded; i.e.,
192 // Plus(weight, Weight::One()) == Weight::One()
// Convenience overload of in-place pruning. It builds PruneOptions as follows:
//   - AnyArcFilter, so every arc is considered;
//   - distance = nullptr, so the options-based Prune() computes the
//     final-state distances itself;
//   - the default threshold_initial (false).
// NOTE(review): delta is declared double here, but PruneOptions (and the
// copying convenience overload) take a float, so it is narrowed when `opts`
// is built.
// NOTE(review): the `template <class Arc>` line and the call that forwards
// `opts` to Prune() are not visible in this view of the file.
194 void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
195 typename Arc::StateId state_threshold = kNoStateId,
196 double delta = kDelta) {
197 const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
198 weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
202 // Pruning algorithm: this version writes the pruned input FST to an
203 // output MutableFst and it takes an options class as an argument. The
204 // output FST contains states and arcs that belong to a successful
205 // path in the input FST whose weight is no more than the weight of the
206 // shortest path Times() the provided weight threshold. When the state
207 // threshold is not kNoStateId, the output FST is further restricted
208 // to have no more than the number of states in
209 // opts.state_threshold. Weights must have the path property. The weight
210 // of any cycle needs to be bounded; i.e.,
212 // Plus(weight, Weight::One()) == Weight::One()
// Copying pruning driven by a PruneOptions object: ofst is cleared and then
// receives the surviving part of ifst. The best-first search mirrors the
// in-place overload, but the per-state vectors grow on demand instead of
// being sized from a state count up front.
// NOTE(review): several lines of this function (early returns, heap pops,
// closing braces) are not visible in this view of the file.
213 template <class Arc, class ArcFilter>
214 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
215 const PruneOptions<Arc, ArcFilter> &opts) {
216 using StateId = typename Arc::StateId;
217 using Weight = typename Arc::Weight;
// Priority queue of input states ordered by PruneCompare.
218 using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
219 // TODO(kbg): Make this a compile-time static_assert once:
220 // 1) All weight properties are made constexpr for all weight types.
221 // 2) We have a pleasant way to "deregister" this operation for non-path
222 // semirings so an informative error message is produced. The best
223 // solution will probably involve some kind of SFINAE magic.
224 if ((Weight::Properties() & kPath) != kPath) {
225 FSTERROR() << "Prune: Weight needs to have the path property: "
227 ofst->SetProperties(kError, kError);
// Start from an empty output that carries the input's symbol tables; an input
// without a start state therefore yields an empty FST.
230 ofst->DeleteStates();
231 ofst->SetInputSymbols(ifst.InputSymbols());
232 ofst->SetOutputSymbols(ifst.OutputSymbols());
233 if (ifst.Start() == kNoStateId) return;
234 NaturalLess<Weight> less;
// Special cases whose handling is not visible here: a weight threshold
// naturally less than One() (presumably excluding even the shortest path), or
// a state budget of 0.
235 if (less(opts.weight_threshold, Weight::One()) ||
236 (opts.state_threshold == 0)) {
// idistance[s]: best weight found so far from the start state to s. It is
// grown on demand; PruneCompare reads missing entries as Zero().
239 std::vector<Weight> idistance;
// Backing storage for final-state distances when opts.distance is null.
240 std::vector<Weight> tmp;
241 if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta);
// fdistance[s]: shortest distance from s to the final states.
242 const auto *fdistance = opts.distance ? opts.distance : &tmp;
// No distance is recorded for the start state, or no final state is reachable
// from it. The body of this branch is not visible here.
243 if ((fdistance->size() <= ifst.Start()) ||
244 ((*fdistance)[ifst.Start()] == Weight::Zero())) {
247 internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
248 StateHeap heap(compare);
// copy[s]: output state created for input state s, or kNoStateId.
249 std::vector<StateId> copy;
// enqueued[s]: heap key of s, or kNoKey. visited[s]: s has been expanded.
// Both vectors are always grown together.
250 std::vector<size_t> enqueued;
251 std::vector<bool> visited;
252 auto s = ifst.Start();
// Largest admissible path weight (see PruneOptions::threshold_initial).
253 const auto limit = opts.threshold_initial ?
254 Times(opts.weight_threshold, (*fdistance)[s]) :
255 Times((*fdistance)[s], opts.weight_threshold);
// Create the output start state and seed the search with it.
256 while (copy.size() <= s) copy.push_back(kNoStateId);
257 copy[s] = ofst->AddState();
258 ofst->SetStart(copy[s]);
259 while (idistance.size() <= s) idistance.push_back(Weight::Zero());
260 idistance[s] = Weight::One();
261 while (enqueued.size() <= s) {
262 enqueued.push_back(StateHeap::kNoKey);
263 visited.push_back(false);
265 enqueued[s] = heap.Insert(s);
// NOTE(review): the lines that pop `s` from the heap and set visited[s] are
// not visible in this view; visited is only read below.
266 while (!heap.Empty()) {
269 enqueued[s] = StateHeap::kNoKey;
// Keep the final weight only if the best path ending at s is within the limit.
271 if (!less(limit, Times(idistance[s], ifst.Final(s)))) {
272 ofst->SetFinal(copy[s], ifst.Final(s));
274 for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
275 const auto &arc = aiter.Value();
// Arcs rejected by the filter are not copied to ofst. The in-place overload
// instead leaves such arcs where they are.
276 if (!opts.filter(arc)) continue;
// Weight of the best path using this arc: start -> s, the arc itself, then
// nextstate -> final (Zero() if nextstate has no recorded distance).
277 const auto weight = Times(Times(idistance[s], arc.weight),
278 arc.nextstate < fdistance->size() ?
279 (*fdistance)[arc.nextstate] : Weight::Zero());
// Over-limit arcs are dropped.
280 if (less(limit, weight)) continue;
// State-budget check. NOTE(review): it runs before copy[arc.nextstate] is
// consulted. If this branch skips the arc (its body is not visible here),
// then once the budget is reached, arcs into states that already have an
// output copy are dropped as well.
281 if ((opts.state_threshold != kNoStateId) &&
282 (ofst->NumStates() >= opts.state_threshold)) {
// Relax the best known distance to nextstate.
285 while (idistance.size() <= arc.nextstate) {
286 idistance.push_back(Weight::Zero());
288 if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
289 idistance[arc.nextstate] = Times(idistance[s], arc.weight);
// Create the output copy of nextstate on first sight, then copy the arc.
291 while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId);
292 if (copy[arc.nextstate] == kNoStateId) {
293 copy[arc.nextstate] = ofst->AddState();
295 ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
296 copy[arc.nextstate]));
297 while (enqueued.size() <= arc.nextstate) {
298 enqueued.push_back(StateHeap::kNoKey);
299 visited.push_back(false);
// Queue nextstate unless it has already been expanded. A state already in
// the heap is re-sifted because its idistance may have just improved.
301 if (visited[arc.nextstate]) continue;
302 if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
303 enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
305 heap.Update(enqueued[arc.nextstate], arc.nextstate);
311 // Pruning algorithm: this version writes the pruned input FST to an
312 // output MutableFst and simply takes the pruning threshold as an
313 // argument. The output FST contains states and arcs that belong to a
314 // successful path in the input FST whose weight is no more than the
315 // weight of the shortest path Times() the provided weight
316 // threshold. When the state threshold is not kNoStateId, the output
317 // FST is further restricted to have no more than the number of states
318 // given by state_threshold. Weights must have the path property. The
319 // weight of any cycle needs to be bounded; i.e.,
321 // Plus(weight, Weight::One()) == Weight::One()
// Convenience overload of copying pruning. It builds PruneOptions as follows,
// then forwards to the options-based overload:
//   - AnyArcFilter, so every arc is considered;
//   - no pre-computed distances (nullptr);
//   - the default threshold_initial (false).
// NOTE(review): the `template <class Arc>` line and the closing brace are not
// visible in this view of the file.
323 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
324 typename Arc::Weight weight_threshold,
325 typename Arc::StateId state_threshold = kNoStateId,
326 float delta = kDelta) {
327 const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
328 weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
329 Prune(ifst, ofst, opts);
334 #endif // FST_LIB_PRUNE_H_