1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions implementing pruning.
6 #ifndef FST_LIB_PRUNE_H_
7 #define FST_LIB_PRUNE_H_
14 #include <fst/arcfilter.h>
16 #include <fst/shortest-distance.h>
// Heap comparator used by the Prune() functions (through their StateHeap
// typedef). For a state s it forms Times(IDistance(s), FDistance(s)) and
// compares two states' values using the NaturalLess<Weight> member less_.
// NOTE(review): this listing appears to be missing lines in this block: the
// class header, access specifiers, the statement in operator() that compares
// wx and wy, and the closing braces are not visible. Confirm against the
// upstream OpenFst source before relying on the exact definition.
22 template <class StateId, class Weight>
// Keeps references to (does not copy) the two distance vectors, so the
// caller must keep both vectors alive while the comparator is in use.
25 PruneCompare(const std::vector<Weight> &idistance,
26 const std::vector<Weight> &fdistance)
27 : idistance_(idistance), fdistance_(fdistance) {}
// Compares states x and y by their combined weights wx and wy.
// Presumably returns less_(wx, wy); the return statement is not visible here.
29 bool operator()(const StateId x, const StateId y) const {
30 const auto wx = Times(IDistance(x), FDistance(x));
31 const auto wy = Times(IDistance(y), FDistance(y));
// Best weight found so far from the start state to s (Prune() keeps it in
// its idistance vector); states past the end of the vector get Zero().
// NOTE(review): `s < idistance_.size()` mixes StateId with size_t; if StateId
// is signed (assumed), a negative id converts to a huge unsigned value, so it
// also yields Weight::Zero().
36 Weight IDistance(const StateId s) const {
37 return s < idistance_.size() ? idistance_[s] : Weight::Zero();
// Shortest distance from s to the final states (Prune() obtains it from
// opts.distance or ShortestDistance(..., true, ...)); states past the end of
// the vector get Weight::Zero().
40 Weight FDistance(const StateId s) const {
41 return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
// References to caller-owned vectors. The copying Prune() grows its
// idistance vector after constructing the comparator; the size checks above
// always see the current size.
44 const std::vector<Weight> &idistance_;
45 const std::vector<Weight> &fdistance_;
// Ordering applied to the combined weights.
46 NaturalLess<Weight> less_;
49 } // namespace internal
// Options for Prune(). The visible fields are the weight and state
// thresholds, optional precomputed distances and the threshold_initial flag.
// The constructor also takes an arc filter and the ShortestDistance()
// convergence delta.
// NOTE(review): this listing appears to be missing lines. Not visible here:
// the `struct` header, the initializers for `distance` and `delta`, the
// `filter` and `delta` member declarations (Prune() reads both as
// opts.filter / opts.delta), and the closing brace.
51 template <class Arc, class ArcFilter>
53 using StateId = typename Arc::StateId;
54 using Weight = typename Arc::Weight;
// NOTE(review): weight_threshold is a const reference, so std::move() below
// still performs a copy. `delta` is a float here, while the in-place
// Prune(fst, weight_threshold, state_threshold, delta) overload takes a
// double and narrows it when constructing these options.
56 PruneOptions(const Weight &weight_threshold, StateId state_threshold,
57 ArcFilter filter, std::vector<Weight> *distance = nullptr,
58 float delta = kDelta, bool threshold_initial = false)
59 : weight_threshold(std::move(weight_threshold)),
60 state_threshold(state_threshold),
61 filter(std::move(filter)),
64 threshold_initial(threshold_initial) {}
66 // Pruning weight threshold.
67 Weight weight_threshold;
68 // Pruning state threshold.
// kNoStateId disables the state limit; Prune() special-cases a value of 0
// before its main loop.
69 StateId state_threshold;
72 // If non-zero, passes in pre-computed shortest distance to final states.
// When this pointer is non-null, Prune() uses it instead of calling
// ShortestDistance().
73 const std::vector<Weight> *distance;
74 // Determines the degree of convergence required when computing shortest
// distances (Prune() passes it as the delta argument of ShortestDistance()).
77 // Determines if the shortest path weight is left (true) or right
78 // (false) multiplied by the threshold to get the limit for
79 // keeping a state or arc (matters if the semiring is not
// commutative). Concretely, with d the start state's distance to the final
// states, Prune() uses limit = Times(weight_threshold, d) when true and
// limit = Times(d, weight_threshold) when false.
81 bool threshold_initial;
84 // Pruning algorithm: this version modifies its input and it takes an options
85 // class as an argument. After pruning the FST contains states and arcs that
86 // belong to a successful path in the FST whose weight is no more than the
87 // weight of the shortest path Times() the provided weight threshold. When the
88 // state threshold is not kNoStateId, the output FST is further restricted to
89 // have no more than the number of states in opts.state_threshold. Weights must
90 // have the path property. The weight of any cycle needs to be bounded; i.e.,
92 // Plus(weight, Weight::One()) == Weight::One()
// In-place pruning driven by PruneOptions (see the preceding description).
// NOTE(review): several lines of this function appear to be missing from
// this listing. Not visible here:
//   - the end of the error message and the early return after FSTERROR();
//   - the body of the early-exit branch and the heap pop;
//   - the assignments to visited[s] and the write-back of the redirected arc;
//   - the ++num_visited increments and the closing braces.
// The comments below describe only the visible statements.
93 template <class Arc, class ArcFilter>
94 void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &opts) {
95 using StateId = typename Arc::StateId;
96 using Weight = typename Arc::Weight;
97 using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
98 // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
99 // way to "deregister" this operation for non-path semirings so an informative
100 // error message is produced.
// Weights without the path property are rejected by marking fst with kError.
101 if ((Weight::Properties() & kPath) != kPath) {
102 FSTERROR() << "Prune: Weight needs to have the path property: "
104 fst->SetProperties(kError, kError);
107 auto ns = fst->NumStates();
// idistance[s]: best weight found so far from the start state to s
// (Zero() until the state is reached).
109 std::vector<Weight> idistance(ns, Weight::Zero());
// Holds the reverse shortest distances when opts.distance is null.
110 std::vector<Weight> tmp;
111 if (!opts.distance) {
113 ShortestDistance(*fst, &tmp, true, opts.delta);
115 const auto *fdistance = opts.distance ? opts.distance : &tmp;
// Special case (body not visible here) for three situations: a zero state
// budget, a start state with no distance entry, or a start state with no
// path to a final state (distance Zero()).
// NOTE(review): `fdistance->size() <= fst->Start()` compares size_t with
// StateId. A kNoStateId start (assumed negative) converts to a huge value,
// so that case also takes this branch.
116 if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) ||
117 ((*fdistance)[fst->Start()] == Weight::Zero())) {
// Priority queue of states ordered by Times(idistance, fdistance).
121 internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
122 StateHeap heap(compare);
123 std::vector<bool> visited(ns, false);
// Heap key of each queued state, or StateHeap::kNoKey if not queued.
124 std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
// States to delete at the end. dead[0] is a freshly added state used as the
// target of pruned arcs. It presumably gets id ns, which is outside the
// range of `visited`, so the final sweep does not add it twice.
125 std::vector<StateId> dead;
126 dead.push_back(fst->AddState());
127 NaturalLess<Weight> less;
128 auto s = fst->Start();
// A final weight or arc is pruned when less(limit, x) holds for its
// best-path weight x. See PruneOptions::threshold_initial for operand order.
129 const auto limit = opts.threshold_initial ?
130 Times(opts.weight_threshold, (*fdistance)[s]) :
131 Times((*fdistance)[s], opts.weight_threshold);
// Number of states counted so far, compared against opts.state_threshold.
132 StateId num_visited = 0;
// Seed the search with the start state only if it is within the limit.
134 if (!less(limit, (*fdistance)[s])) {
135 idistance[s] = Weight::One();
136 enqueued[s] = heap.Insert(s);
// Best-first traversal from the start state (the pop of s is not visible).
139 while (!heap.Empty()) {
142 enqueued[s] = StateHeap::kNoKey;
// Drop the final weight if the path ending at s exceeds the limit.
144 if (less(limit, Times(idistance[s], fst->Final(s)))) {
145 fst->SetFinal(s, Weight::Zero());
147 for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
149 auto arc = aiter.Value(); // Copy intended.
// Arcs rejected by the filter are left untouched.
150 if (!opts.filter(arc)) continue;
// Best weight of a successful path through this arc. A destination with no
// fdistance entry counts as Zero().
151 const auto weight = Times(Times(idistance[s], arc.weight),
152 arc.nextstate < fdistance->size() ?
153 (*fdistance)[arc.nextstate] : Weight::Zero());
// Over the limit: redirect the arc to the dead state. It is presumably
// removed when DeleteStates() deletes that state below; the write-back via
// the iterator is not visible here.
154 if (less(limit, weight)) {
155 arc.nextstate = dead[0];
// Relax the best-known distance to the destination.
159 if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
160 idistance[arc.nextstate] = Times(idistance[s], arc.weight);
162 if (visited[arc.nextstate]) continue;
// State budget check (the branch body is not visible here).
163 if ((opts.state_threshold != kNoStateId) &&
164 (num_visited >= opts.state_threshold)) {
// Queue the destination, or re-sort it since its idistance may have improved.
167 if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
168 enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
171 heap.Update(enqueued[arc.nextstate], arc.nextstate);
// Every original state that was never visited is deleted, along with dead[0].
175 for (StateId i = 0; i < visited.size(); ++i) {
176 if (!visited[i]) dead.push_back(i);
178 fst->DeleteStates(dead);
181 // Pruning algorithm: this version modifies its input and takes the
182 // pruning threshold as an argument. It deletes states and arcs in the
183 // FST that do not belong to a successful path whose weight is no more
184 // than the weight of the shortest path Times() the provided weight
185 // threshold. When the state threshold is not kNoStateId, the output
186 // FST is further restricted to have no more than state_threshold
187 // states. Weights must have the path property. The
188 // weight of any cycle needs to be bounded; i.e.,
190 // Plus(weight, Weight::One()) == Weight::One()
// Convenience overload of in-place pruning. It builds PruneOptions with
// AnyArcFilter<Arc> (presumably accepting every arc), no precomputed
// distances and the given delta, and presumably forwards to Prune(fst, opts).
// NOTE(review): the template header declaring Arc, the forwarding call and
// the closing brace are not visible in this listing. `delta` is a double
// here, but PruneOptions takes a float, so it is narrowed on construction.
// The copying threshold overload takes a float instead.
192 void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
193 typename Arc::StateId state_threshold = kNoStateId,
194 double delta = kDelta) {
195 const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
196 weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
200 // Pruning algorithm: this version writes the pruned input FST to an
201 // output MutableFst and it takes an options class as an argument. The
202 // output FST contains states and arcs that belong to a successful
203 // path in the input FST whose weight is no more than the weight of the
204 // shortest path Times() the provided weight threshold. When the state
205 // threshold is not kNoStateId, the output FST is further restricted
206 // to have no more than the number of states in
207 // opts.state_threshold. Weights must have the path property. The weight
208 // of any cycle needs to be bounded; i.e.,
210 // Plus(weight, Weight::One()) == Weight::One()
// Copying pruning driven by PruneOptions: surviving states and arcs of ifst
// are written into ofst.
// NOTE(review): several lines appear to be missing from this listing. Not
// visible here:
//   - the end of the error message and the return after FSTERROR();
//   - the bodies of the two early-exit branches and the heap pop;
//   - the assignment to visited[s];
//   - the body of the state-threshold branch and the closing braces.
// The comments below describe only the visible statements.
211 template <class Arc, class ArcFilter>
212 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
213 const PruneOptions<Arc, ArcFilter> &opts) {
214 using StateId = typename Arc::StateId;
215 using Weight = typename Arc::Weight;
216 using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
217 // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
218 // way to "deregister" this operation for non-path semirings so an informative
219 // error message is produced.
// Weights without the path property are rejected by marking ofst with kError.
220 if ((Weight::Properties() & kPath) != kPath) {
221 FSTERROR() << "Prune: Weight needs to have the path property: "
223 ofst->SetProperties(kError, kError);
// Start from an empty output that shares the input's symbol tables. An input
// without a start state yields that empty FST.
226 ofst->DeleteStates();
227 ofst->SetInputSymbols(ifst.InputSymbols());
228 ofst->SetOutputSymbols(ifst.OutputSymbols());
229 if (ifst.Start() == kNoStateId) return;
230 NaturalLess<Weight> less;
// Special case (body not visible here) for a threshold ordered before One()
// by NaturalLess, or a zero state budget.
231 if (less(opts.weight_threshold, Weight::One()) ||
232 (opts.state_threshold == 0)) {
// The per-state vectors below are indexed by input state id and grown lazily.
235 std::vector<Weight> idistance;
236 std::vector<Weight> tmp;
237 if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta);
238 const auto *fdistance = opts.distance ? opts.distance : &tmp;
// Special case (body not visible here): the start state has no distance
// entry or cannot reach a final state.
239 if ((fdistance->size() <= ifst.Start()) ||
240 ((*fdistance)[ifst.Start()] == Weight::Zero())) {
// The comparator holds a reference to idistance, so it sees the vector as it
// grows below.
243 internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
244 StateHeap heap(compare);
// copy[s]: output state created for input state s, or kNoStateId.
245 std::vector<StateId> copy;
// enqueued[s]: heap key of s, or StateHeap::kNoKey if not queued.
246 std::vector<size_t> enqueued;
247 std::vector<bool> visited;
248 auto s = ifst.Start();
// A final weight or arc is dropped when less(limit, x) holds for its
// best-path weight x. See PruneOptions::threshold_initial for operand order.
249 const auto limit = opts.threshold_initial ?
250 Times(opts.weight_threshold, (*fdistance)[s]) :
251 Times((*fdistance)[s], opts.weight_threshold);
// Create the output start state and seed the queue with the input start.
252 while (copy.size() <= s) copy.push_back(kNoStateId);
253 copy[s] = ofst->AddState();
254 ofst->SetStart(copy[s]);
255 while (idistance.size() <= s) idistance.push_back(Weight::Zero());
256 idistance[s] = Weight::One();
257 while (enqueued.size() <= s) {
258 enqueued.push_back(StateHeap::kNoKey);
259 visited.push_back(false);
261 enqueued[s] = heap.Insert(s);
// Best-first traversal (the pop of s is not visible here).
262 while (!heap.Empty()) {
265 enqueued[s] = StateHeap::kNoKey;
// Keep the final weight only if the path ending at s is within the limit.
267 if (!less(limit, Times(idistance[s], ifst.Final(s)))) {
268 ofst->SetFinal(copy[s], ifst.Final(s));
270 for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
271 const auto &arc = aiter.Value();
// Arcs rejected by the filter are not copied.
272 if (!opts.filter(arc)) continue;
273 const auto weight = Times(Times(idistance[s], arc.weight),
274 arc.nextstate < fdistance->size() ?
275 (*fdistance)[arc.nextstate] : Weight::Zero());
// Arcs whose best successful path exceeds the limit are not copied.
276 if (less(limit, weight)) continue;
// State budget check (the branch body is not visible here).
// NOTE(review): this test runs before the copy[] lookup below. Whatever that
// branch does therefore also applies to arcs whose destination state has
// already been copied; confirm this is intended.
277 if ((opts.state_threshold != kNoStateId) &&
278 (ofst->NumStates() >= opts.state_threshold)) {
281 while (idistance.size() <= arc.nextstate) {
282 idistance.push_back(Weight::Zero());
// Relax the best-known distance to the destination.
284 if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
285 idistance[arc.nextstate] = Times(idistance[s], arc.weight);
// Create the output destination state on first use, then copy the arc with
// its destination remapped through copy[].
287 while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId);
288 if (copy[arc.nextstate] == kNoStateId) {
289 copy[arc.nextstate] = ofst->AddState();
291 ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
292 copy[arc.nextstate]));
293 while (enqueued.size() <= arc.nextstate) {
294 enqueued.push_back(StateHeap::kNoKey);
295 visited.push_back(false);
297 if (visited[arc.nextstate]) continue;
// Queue the destination, or re-sort it since its idistance may have improved.
298 if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
299 enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
301 heap.Update(enqueued[arc.nextstate], arc.nextstate);
307 // Pruning algorithm: this version writes the pruned input FST to an
308 // output MutableFst and simply takes the pruning threshold as an
309 // argument. The output FST contains states and arcs that belong to a
310 // successful path in the input FST whose weight is no more than the
311 // weight of the shortest path Times() the provided weight
312 // threshold. When the state threshold is not kNoStateId, the output
313 // FST is further restricted to have no more than state_threshold
314 // states. Weights must have the path property. The
315 // weight of any cycle needs to be bounded; i.e.,
317 // Plus(weight, Weight::One()) == Weight::One()
// Convenience overload of copying pruning. It builds PruneOptions with
// AnyArcFilter<Arc> (presumably accepting every arc), no precomputed
// distances and the given delta, and calls Prune(ifst, ofst, opts).
// NOTE(review): the template header declaring Arc and the closing brace are
// not visible in this listing. This overload takes a float delta, while the
// in-place threshold overload takes a double.
319 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
320 typename Arc::Weight weight_threshold,
321 typename Arc::StateId state_threshold = kNoStateId,
322 float delta = kDelta) {
323 const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
324 weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
325 Prune(ifst, ofst, opts);
330 #endif // FST_LIB_PRUNE_H_