1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions implementing pruning.
15 #include <fst/arcfilter.h>
17 #include <fst/shortest-distance.h>
// Comparator for the state heap used by Prune(). It ranks a state s by the
// combined weight Times(IDistance(s), FDistance(s)): the best known weight
// from the start state to s, times the shortest distance from s to a final
// state. The final comparison of wx and wy is not visible in this view; it
// presumably applies less_ (NaturalLess<Weight>) to them.
//
// Both distance vectors are held by const reference, so they must outlive
// this comparator, and later writes to them (e.g. Prune() relaxing or growing
// idistance) are observed by it.
23 template <class StateId, class Weight>
26 PruneCompare(const std::vector<Weight> &idistance,
27 const std::vector<Weight> &fdistance)
28 : idistance_(idistance), fdistance_(fdistance) {}
30 bool operator()(const StateId x, const StateId y) const {
// Weight of the best path through each state (start -> state -> final).
31 const auto wx = Times(IDistance(x), FDistance(x));
32 const auto wy = Times(IDistance(y), FDistance(y));
// Best known distance from the start state to s. Ids past the end of the
// vector (which Prune() may grow on demand) read as Weight::Zero().
37 Weight IDistance(const StateId s) const {
38 return s < idistance_.size() ? idistance_[s] : Weight::Zero();
// Shortest distance from s to a final state. Ids past the end of the vector
// read as Weight::Zero(), i.e. no known path to a final state.
41 Weight FDistance(const StateId s) const {
42 return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
45 const std::vector<Weight> &idistance_;
46 const std::vector<Weight> &fdistance_;
// Natural-order comparator for Weight.
47 NaturalLess<Weight> less_;
50 } // namespace internal
// Options controlling the Prune() overloads that accept a PruneOptions.
52 template <class Arc, class ArcFilter>
54 using StateId = typename Arc::StateId;
55 using Weight = typename Arc::Weight;
// weight_threshold arrives as a const reference, so the std::move() in the
// initializer list below copies it rather than moving it.
57 PruneOptions(const Weight &weight_threshold, StateId state_threshold,
58 ArcFilter filter, std::vector<Weight> *distance = nullptr,
59 float delta = kDelta, bool threshold_initial = false)
60 : weight_threshold(std::move(weight_threshold)),
61 state_threshold(state_threshold),
62 filter(std::move(filter)),
65 threshold_initial(threshold_initial) {}
67 // Pruning weight threshold.
68 Weight weight_threshold;
69 // Pruning state threshold.
// kNoStateId means no limit on the number of states kept.
70 StateId state_threshold;
73 // If non-zero, passes in pre-computed shortest distance to final states.
// When null, Prune() computes these distances itself with
// ShortestDistance(..., true, delta). The vector is only read; the caller
// presumably keeps ownership and must keep it alive while pruning.
74 const std::vector<Weight> *distance;
75 // Determines the degree of convergence required when computing shortest
78 // Determines if the shortest path weight is left (true) or right
79 // (false) multiplied by the threshold to get the limit for
80 // keeping a state or arc (matters if the semiring is not
82 bool threshold_initial;
85 // Pruning algorithm: this version modifies its input and takes an options
86 // class as an argument. After pruning, the FST contains only states and arcs
87 // that belong to a successful path in the FST whose weight is no more than the
88 // weight of the shortest path Times() the provided weight threshold. When the
89 // state threshold is not kNoStateId, the pruned FST is further restricted to
90 // have no more than opts.state_threshold states. Weights must have the path
91 // property. The weight of any cycle needs to be bounded; i.e.,
93 // Plus(weight, Weight::One()) == Weight::One()
// Path-property overload (selected via enable_if on kPath). Strategy:
// best-first search from the start state, ordered by PruneCompare, that keeps
// only states/arcs lying on some successful path whose weight is within
// `limit`. Pruned arcs are redirected to a dead sink state. At the end, that
// sink plus every never-visited state is deleted.
94 template <class Arc, class ArcFilter,
95 typename std::enable_if<
96 (Arc::Weight::Properties() & kPath) == kPath>::type * = nullptr>
97 void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &opts) {
98 using StateId = typename Arc::StateId;
99 using Weight = typename Arc::Weight;
100 using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
101 auto ns = fst->NumStates();
// idistance[s]: best weight found so far from the start state to s
// (Weight::Zero() until s is reached).
103 std::vector<Weight> idistance(ns, Weight::Zero());
// Distances to final states: use the caller's if provided (opts.distance),
// otherwise compute them into tmp.
104 std::vector<Weight> tmp;
105 if (!opts.distance) {
107 ShortestDistance(*fst, &tmp, true, opts.delta);
109 const auto *fdistance = opts.distance ? opts.distance : &tmp;
// Nothing can be kept when the state budget is zero or the start state
// cannot reach a final state. The handling is elided from this view; it
// presumably clears the FST and returns.
110 if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) ||
111 ((*fdistance)[fst->Start()] == Weight::Zero())) {
115 internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
116 StateHeap heap(compare);
// visited[s]: s has already been expanded (presumably set when it is popped;
// that line is not visible here). enqueued[s]: heap key of s, or
// StateHeap::kNoKey when s is not in the heap.
117 std::vector<bool> visited(ns, false);
118 std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
// dead[0] is a fresh sink state that receives pruned arcs. Unvisited states
// are appended later, and everything in `dead` is deleted at the end. The
// sink is added after the per-state vectors above were sized, so they have
// no entry for it.
119 std::vector<StateId> dead;
120 dead.push_back(fst->AddState());
121 NaturalLess<Weight> less;
122 auto s = fst->Start();
// limit = shortest-path weight Times the threshold. The operand order is
// chosen by opts.threshold_initial; it only matters if Times is not
// commutative. A weight w is pruned when less(limit, w).
123 const auto limit = opts.threshold_initial ?
124 Times(opts.weight_threshold, (*fdistance)[s]) :
125 Times((*fdistance)[s], opts.weight_threshold);
126 StateId num_visited = 0;
// Seed the search only if the start state itself lies within the limit.
128 if (!less(limit, (*fdistance)[s])) {
129 idistance[s] = Weight::One();
130 enqueued[s] = heap.Insert(s);
// Best-first expansion. Lines elided from this view presumably pop the best
// state into `s` and mark it visited.
133 while (!heap.Empty()) {
136 enqueued[s] = StateHeap::kNoKey;
// Drop the final weight if the best path ending at s exceeds the limit.
138 if (less(limit, Times(idistance[s], fst->Final(s)))) {
139 fst->SetFinal(s, Weight::Zero());
141 for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
143 auto arc = aiter.Value(); // Copy intended.
// Arcs rejected by the filter are neither pruned nor followed.
144 if (!opts.filter(arc)) continue;
// Weight of the best successful path that uses this arc.
145 const auto weight = Times(Times(idistance[s], arc.weight),
146 arc.nextstate < fdistance->size() ?
147 (*fdistance)[arc.nextstate] : Weight::Zero());
// Over the limit: point the arc at the dead sink so it disappears when the
// dead states are deleted. The following lines (presumably writing the arc
// back and continuing) are elided from this view.
148 if (less(limit, weight)) {
149 arc.nextstate = dead[0];
// Relax the best-known start distance of the arc's target.
153 if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
154 idistance[arc.nextstate] = Times(idistance[s], arc.weight);
156 if (visited[arc.nextstate]) continue;
// Stop admitting new states once the budget is used up. num_visited is
// presumably incremented on each admission; that handling is elided from
// this view.
157 if ((opts.state_threshold != kNoStateId) &&
158 (num_visited >= opts.state_threshold)) {
// Enqueue a newly reached state, or update its heap position after its
// distance changed.
161 if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
162 enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
165 heap.Update(enqueued[arc.nextstate], arc.nextstate);
// Delete the dead sink and every state that was never visited.
169 for (StateId i = 0; i < visited.size(); ++i) {
170 if (!visited[i]) dead.push_back(i);
172 fst->DeleteStates(dead);
// Overload selected when Weight lacks the path property. Pruning is not
// supported, so this reports an error and marks the FST with kError.
175 template <class Arc, class ArcFilter,
176 typename std::enable_if<
177 (Arc::Weight::Properties() & kPath) != kPath>::type * = nullptr>
178 void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &) {
179 FSTERROR() << "Prune: Weight needs to have the path property: "
180 << Arc::Weight::Type();
181 fst->SetProperties(kError, kError);
184 // Pruning algorithm: this version modifies its input and takes the
185 // pruning threshold as an argument. It deletes states and arcs in the
186 // FST that do not belong to a successful path whose weight is no more
187 // than the weight of the shortest path Times() the provided weight
188 // threshold. When state_threshold is not kNoStateId, the pruned FST is
189 // further restricted to have no more than state_threshold states. Weights
190 // must have the path property. The weight of any cycle needs to be
191 // bounded; i.e.,
193 // Plus(weight, Weight::One()) == Weight::One()
// Convenience wrapper. It packages the thresholds and delta into a
// PruneOptions with AnyArcFilter<Arc> (presumably accepting every arc), no
// precomputed distances and threshold_initial = false, then prunes in place.
// The forwarding call is not visible in this view.
195 void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
196 typename Arc::StateId state_threshold = kNoStateId,
197 float delta = kDelta) {
198 const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
199 weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
203 // Pruning algorithm: this version writes the pruned input FST to an
204 // output MutableFst and takes an options class as an argument. The
205 // output FST contains only the states and arcs that belong to a successful
206 // path in the input FST whose weight is no more than the weight of the
207 // shortest path Times() the provided weight threshold. When the state
208 // threshold is not kNoStateId, the output FST is further restricted
209 // to have no more than opts.state_threshold states. Weights must
210 // have the path property. The weight of any cycle needs to be
211 // bounded; i.e.,
213 // Plus(weight, Weight::One()) == Weight::One()
// Path-property overload that builds the result in *ofst instead of modifying
// the input. It runs the same best-first search as the in-place version, but
// surviving states are copied into ofst on demand (copy[] maps input ids to
// output ids), and pruned arcs are simply not copied.
214 template <class Arc, class ArcFilter,
215 typename std::enable_if<
216 (Arc::Weight::Properties() & kPath) == kPath>::type * = nullptr>
217 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
218 const PruneOptions<Arc, ArcFilter> &opts) {
219 using StateId = typename Arc::StateId;
220 using Weight = typename Arc::Weight;
221 using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
// Start from an empty output that carries the input's symbol tables. An
// input without a start state yields an empty output.
222 ofst->DeleteStates();
223 ofst->SetInputSymbols(ifst.InputSymbols());
224 ofst->SetOutputSymbols(ifst.OutputSymbols());
225 if (ifst.Start() == kNoStateId) return;
226 NaturalLess<Weight> less;
// A threshold strictly below One in the natural order, or a zero state
// budget, presumably leaves nothing to keep. The early exit is elided from
// this view.
227 if (less(opts.weight_threshold, Weight::One()) ||
228 (opts.state_threshold == 0)) {
// Per-state tables are indexed by input state id and grown on demand (see
// the while/push_back loops below) rather than sized up front. Presumably
// this is because a generic Fst<Arc> need not report its state count.
231 std::vector<Weight> idistance;
232 std::vector<Weight> tmp;
// Distances to final states: caller-supplied, or computed into tmp.
233 if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta);
234 const auto *fdistance = opts.distance ? opts.distance : &tmp;
// If the start state cannot reach a final state, there is nothing to copy.
// The handling is elided from this view.
235 if ((fdistance->size() <= ifst.Start()) ||
236 ((*fdistance)[ifst.Start()] == Weight::Zero())) {
239 internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
240 StateHeap heap(compare);
// copy[s]: id of input state s in ofst, or kNoStateId if not created yet.
241 std::vector<StateId> copy;
242 std::vector<size_t> enqueued;
243 std::vector<bool> visited;
244 auto s = ifst.Start();
// limit = shortest-path weight Times the threshold, with the operand order
// chosen by opts.threshold_initial. A weight w is pruned when less(limit, w).
245 const auto limit = opts.threshold_initial ?
246 Times(opts.weight_threshold, (*fdistance)[s]) :
247 Times((*fdistance)[s], opts.weight_threshold);
// Create the output start state and seed the search with it.
248 while (copy.size() <= s) copy.push_back(kNoStateId);
249 copy[s] = ofst->AddState();
250 ofst->SetStart(copy[s]);
251 while (idistance.size() <= s) idistance.push_back(Weight::Zero());
252 idistance[s] = Weight::One();
253 while (enqueued.size() <= s) {
254 enqueued.push_back(StateHeap::kNoKey);
255 visited.push_back(false);
257 enqueued[s] = heap.Insert(s);
// Best-first expansion. Lines elided from this view presumably pop the best
// state into `s` and mark it visited.
258 while (!heap.Empty()) {
261 enqueued[s] = StateHeap::kNoKey;
// Copy the final weight only if the best path ending at s is within limit.
263 if (!less(limit, Times(idistance[s], ifst.Final(s)))) {
264 ofst->SetFinal(copy[s], ifst.Final(s));
266 for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
267 const auto &arc = aiter.Value();
// Arcs rejected by the filter are neither copied nor followed.
268 if (!opts.filter(arc)) continue;
// Weight of the best successful path that uses this arc. The arc is skipped
// if this exceeds the limit.
269 const auto weight = Times(Times(idistance[s], arc.weight),
270 arc.nextstate < fdistance->size() ?
271 (*fdistance)[arc.nextstate] : Weight::Zero());
272 if (less(limit, weight)) continue;
// The state budget counts output states already created. The handling is
// elided from this view.
273 if ((opts.state_threshold != kNoStateId) &&
274 (ofst->NumStates() >= opts.state_threshold)) {
// Relax the best-known start distance of the arc's target.
277 while (idistance.size() <= arc.nextstate) {
278 idistance.push_back(Weight::Zero());
280 if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
281 idistance[arc.nextstate] = Times(idistance[s], arc.weight);
// Create the target's output state on first use, then copy the arc.
283 while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId);
284 if (copy[arc.nextstate] == kNoStateId) {
285 copy[arc.nextstate] = ofst->AddState();
287 ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
288 copy[arc.nextstate]));
// Enqueue a newly reached, not-yet-visited state, or update its heap
// position after its distance changed.
289 while (enqueued.size() <= arc.nextstate) {
290 enqueued.push_back(StateHeap::kNoKey);
291 visited.push_back(false);
293 if (visited[arc.nextstate]) continue;
294 if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
295 enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
297 heap.Update(enqueued[arc.nextstate], arc.nextstate);
// Overload selected when Weight lacks the path property. It reports an error
// and marks the output FST with kError; the input FST is not read.
303 template <class Arc, class ArcFilter,
304 typename std::enable_if<
305 (Arc::Weight::Properties() & kPath) != kPath>::type * = nullptr>
306 void Prune(const Fst<Arc> &, MutableFst<Arc> *ofst,
307 const PruneOptions<Arc, ArcFilter> &) {
308 FSTERROR() << "Prune: Weight needs to have the path property: "
309 << Arc::Weight::Type();
310 ofst->SetProperties(kError, kError);
313 // Pruning algorithm: this version writes the pruned input FST to an
314 // output MutableFst and simply takes the pruning threshold as an
315 // argument. The output FST contains states and arcs that belong to a
316 // successful path in the input FST whose weight is no more than the
317 // weight of the shortest path Times() the provided weight
318 // threshold. When the state threshold is not kNoStateId, the output
319 // FST is further restricted to have no more than the number of states
320 // in state_threshold. Weights must have the path property. The
321 // weight of any cycle needs to be bounded; i.e.,
323 // Plus(weight, Weight::One()) == Weight::One()
// Convenience wrapper. It packages the thresholds and delta into a
// PruneOptions with AnyArcFilter<Arc> (presumably accepting every arc), no
// precomputed distances and threshold_initial = false, then forwards to the
// options-based overload, which writes the pruned copy of ifst to *ofst.
325 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
326 typename Arc::Weight weight_threshold,
327 typename Arc::StateId state_threshold = kNoStateId,
328 float delta = kDelta) {
329 const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
330 weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
331 Prune(ifst, ofst, opts);
336 #endif // FST_PRUNE_H_