Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / prune.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions implementing pruning.
5
6 #ifndef FST_LIB_PRUNE_H_
7 #define FST_LIB_PRUNE_H_
8
9 #include <utility>
10 #include <vector>
11
12 #include <fst/log.h>
13
14 #include <fst/arcfilter.h>
15 #include <fst/heap.h>
16 #include <fst/shortest-distance.h>
17
18
19 namespace fst {
20 namespace internal {
21
22 template <class StateId, class Weight>
23 class PruneCompare {
24  public:
25   PruneCompare(const std::vector<Weight> &idistance,
26                const std::vector<Weight> &fdistance)
27       : idistance_(idistance), fdistance_(fdistance) {}
28
29   bool operator()(const StateId x, const StateId y) const {
30     const auto wx = Times(IDistance(x), FDistance(x));
31     const auto wy = Times(IDistance(y), FDistance(y));
32     return less_(wx, wy);
33   }
34
35  private:
36   Weight IDistance(const StateId s) const {
37     return s < idistance_.size() ? idistance_[s] : Weight::Zero();
38   }
39
40   Weight FDistance(const StateId s) const {
41     return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
42   }
43
44   const std::vector<Weight> &idistance_;
45   const std::vector<Weight> &fdistance_;
46   NaturalLess<Weight> less_;
47 };
48
49 }  // namespace internal
50
51 template <class Arc, class ArcFilter>
52 struct PruneOptions {
53   using StateId = typename Arc::StateId;
54   using Weight = typename Arc::Weight;
55
56   PruneOptions(const Weight &weight_threshold, StateId state_threshold,
57                ArcFilter filter, std::vector<Weight> *distance = nullptr,
58                float delta = kDelta, bool threshold_initial = false)
59       : weight_threshold(std::move(weight_threshold)),
60         state_threshold(state_threshold),
61         filter(std::move(filter)),
62         distance(distance),
63         delta(delta),
64         threshold_initial(threshold_initial) {}
65
66   // Pruning weight threshold.
67   Weight weight_threshold;
68   // Pruning state threshold.
69   StateId state_threshold;
70   // Arc filter.
71   ArcFilter filter;
72   // If non-zero, passes in pre-computed shortest distance to final states.
73   const std::vector<Weight> *distance;
74   // Determines the degree of convergence required when computing shortest
75   // distances.
76   float delta;
77   // Determines if the shortest path weight is left (true) or right
78   // (false) multiplied by the threshold to get the limit for
79   // keeping a state or arc (matters if the semiring is not
80   // commutative).
81   bool threshold_initial;
82 };
83
84 // Pruning algorithm: this version modifies its input and it takes an options
85 // class as an argument. After pruning the FST contains states and arcs that
86 // belong to a successful path in the FST whose weight is no more than the
87 // weight of the shortest path Times() the provided weight threshold. When the
88 // state threshold is not kNoStateId, the output FST is further restricted to
89 // have no more than the number of states in opts.state_threshold. Weights must
90 // have the path property. The weight of any cycle needs to be bounded; i.e.,
91 //
92 //   Plus(weight, Weight::One()) == Weight::One()
93 template <class Arc, class ArcFilter>
94 void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &opts) {
95   using StateId = typename Arc::StateId;
96   using Weight = typename Arc::Weight;
97   using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
98   // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
99   // way to "deregister" this operation for non-path semirings so an informative
100   // error message is produced.
101   if ((Weight::Properties() & kPath) != kPath) {
102     FSTERROR() << "Prune: Weight needs to have the path property: "
103                << Weight::Type();
104     fst->SetProperties(kError, kError);
105     return;
106   }
107   auto ns = fst->NumStates();
108   if (ns < 1) return;
109   std::vector<Weight> idistance(ns, Weight::Zero());
110   std::vector<Weight> tmp;
111   if (!opts.distance) {
112     tmp.reserve(ns);
113     ShortestDistance(*fst, &tmp, true, opts.delta);
114   }
115   const auto *fdistance = opts.distance ? opts.distance : &tmp;
116   if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) ||
117       ((*fdistance)[fst->Start()] == Weight::Zero())) {
118     fst->DeleteStates();
119     return;
120   }
121   internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
122   StateHeap heap(compare);
123   std::vector<bool> visited(ns, false);
124   std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
125   std::vector<StateId> dead;
126   dead.push_back(fst->AddState());
127   NaturalLess<Weight> less;
128   auto s = fst->Start();
129   const auto limit = opts.threshold_initial ?
130       Times(opts.weight_threshold, (*fdistance)[s]) :
131       Times((*fdistance)[s], opts.weight_threshold);
132   StateId num_visited = 0;
133
134   if (!less(limit, (*fdistance)[s])) {
135     idistance[s] = Weight::One();
136     enqueued[s] = heap.Insert(s);
137     ++num_visited;
138   }
139   while (!heap.Empty()) {
140     s = heap.Top();
141     heap.Pop();
142     enqueued[s] = StateHeap::kNoKey;
143     visited[s] = true;
144     if (less(limit, Times(idistance[s], fst->Final(s)))) {
145       fst->SetFinal(s, Weight::Zero());
146     }
147     for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
148          aiter.Next()) {
149       auto arc = aiter.Value();  // Copy intended.
150       if (!opts.filter(arc)) continue;
151       const auto weight = Times(Times(idistance[s], arc.weight),
152                                 arc.nextstate < fdistance->size() ?
153                                 (*fdistance)[arc.nextstate] : Weight::Zero());
154       if (less(limit, weight)) {
155         arc.nextstate = dead[0];
156         aiter.SetValue(arc);
157         continue;
158       }
159       if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
160         idistance[arc.nextstate] = Times(idistance[s], arc.weight);
161       }
162       if (visited[arc.nextstate]) continue;
163       if ((opts.state_threshold != kNoStateId) &&
164           (num_visited >= opts.state_threshold)) {
165         continue;
166       }
167       if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
168         enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
169         ++num_visited;
170       } else {
171         heap.Update(enqueued[arc.nextstate], arc.nextstate);
172       }
173     }
174   }
175   for (StateId i = 0; i < visited.size(); ++i) {
176     if (!visited[i]) dead.push_back(i);
177   }
178   fst->DeleteStates(dead);
179 }
180
181 // Pruning algorithm: this version modifies its input and takes the
182 // pruning threshold as an argument. It deletes states and arcs in the
183 // FST that do not belong to a successful path whose weight is more
184 // than the weight of the shortest path Times() the provided weight
185 // threshold. When the state threshold is not kNoStateId, the output
186 // FST is further restricted to have no more than the number of states
187 // in opts.state_threshold. Weights must have the path property. The
188 // weight of any cycle needs to be bounded; i.e.,
189 //
190 //   Plus(weight, Weight::One()) == Weight::One()
191 template <class Arc>
192 void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
193            typename Arc::StateId state_threshold = kNoStateId,
194            double delta = kDelta) {
195   const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
196       weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
197   Prune(fst, opts);
198 }
199
200 // Pruning algorithm: this version writes the pruned input FST to an
201 // output MutableFst and it takes an options class as an argument. The
202 // output FST contains states and arcs that belong to a successful
203 // path in the input FST whose weight is more than the weight of the
204 // shortest path Times() the provided weight threshold. When the state
205 // threshold is not kNoStateId, the output FST is further restricted
206 // to have no more than the number of states in
207 // opts.state_threshold. Weights have the path property.  The weight
208 // of any cycle needs to be bounded; i.e.,
209 //
210 //   Plus(weight, Weight::One()) == Weight::One()
211 template <class Arc, class ArcFilter>
212 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
213            const PruneOptions<Arc, ArcFilter> &opts) {
214   using StateId = typename Arc::StateId;
215   using Weight = typename Arc::Weight;
216   using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
217   // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
218   // way to "deregister" this operation for non-path semirings so an informative
219   // error message is produced.
220   if ((Weight::Properties() & kPath) != kPath) {
221     FSTERROR() << "Prune: Weight needs to have the path property: "
222                << Weight::Type();
223     ofst->SetProperties(kError, kError);
224     return;
225   }
226   ofst->DeleteStates();
227   ofst->SetInputSymbols(ifst.InputSymbols());
228   ofst->SetOutputSymbols(ifst.OutputSymbols());
229   if (ifst.Start() == kNoStateId) return;
230   NaturalLess<Weight> less;
231   if (less(opts.weight_threshold, Weight::One()) ||
232       (opts.state_threshold == 0)) {
233     return;
234   }
235   std::vector<Weight> idistance;
236   std::vector<Weight> tmp;
237   if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta);
238   const auto *fdistance = opts.distance ? opts.distance : &tmp;
239   if ((fdistance->size() <= ifst.Start()) ||
240       ((*fdistance)[ifst.Start()] == Weight::Zero())) {
241     return;
242   }
243   internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
244   StateHeap heap(compare);
245   std::vector<StateId> copy;
246   std::vector<size_t> enqueued;
247   std::vector<bool> visited;
248   auto s = ifst.Start();
249   const auto limit = opts.threshold_initial ?
250       Times(opts.weight_threshold, (*fdistance)[s]) :
251       Times((*fdistance)[s], opts.weight_threshold);
252   while (copy.size() <= s) copy.push_back(kNoStateId);
253   copy[s] = ofst->AddState();
254   ofst->SetStart(copy[s]);
255   while (idistance.size() <= s) idistance.push_back(Weight::Zero());
256   idistance[s] = Weight::One();
257   while (enqueued.size() <= s) {
258     enqueued.push_back(StateHeap::kNoKey);
259     visited.push_back(false);
260   }
261   enqueued[s] = heap.Insert(s);
262   while (!heap.Empty()) {
263     s = heap.Top();
264     heap.Pop();
265     enqueued[s] = StateHeap::kNoKey;
266     visited[s] = true;
267     if (!less(limit, Times(idistance[s], ifst.Final(s)))) {
268       ofst->SetFinal(copy[s], ifst.Final(s));
269     }
270     for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
271       const auto &arc = aiter.Value();
272       if (!opts.filter(arc)) continue;
273       const auto weight = Times(Times(idistance[s], arc.weight),
274                                 arc.nextstate < fdistance->size() ?
275                                 (*fdistance)[arc.nextstate] : Weight::Zero());
276       if (less(limit, weight)) continue;
277       if ((opts.state_threshold != kNoStateId) &&
278           (ofst->NumStates() >= opts.state_threshold)) {
279         continue;
280       }
281       while (idistance.size() <= arc.nextstate) {
282         idistance.push_back(Weight::Zero());
283       }
284       if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
285         idistance[arc.nextstate] = Times(idistance[s], arc.weight);
286       }
287       while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId);
288       if (copy[arc.nextstate] == kNoStateId) {
289         copy[arc.nextstate] = ofst->AddState();
290       }
291       ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
292                                 copy[arc.nextstate]));
293       while (enqueued.size() <= arc.nextstate) {
294         enqueued.push_back(StateHeap::kNoKey);
295         visited.push_back(false);
296       }
297       if (visited[arc.nextstate]) continue;
298       if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
299         enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
300       } else {
301         heap.Update(enqueued[arc.nextstate], arc.nextstate);
302       }
303     }
304   }
305 }
306
307 // Pruning algorithm: this version writes the pruned input FST to an
308 // output MutableFst and simply takes the pruning threshold as an
309 // argument. The output FST contains states and arcs that belong to a
310 // successful path in the input FST whose weight is no more than the
311 // weight of the shortest path Times() the provided weight
312 // threshold. When the state threshold is not kNoStateId, the output
313 // FST is further restricted to have no more than the number of states
314 // in opts.state_threshold. Weights must have the path property. The
315 // weight of any cycle needs to be bounded; i.e.,
316 //
317 // Plus(weight, Weight::One()) = Weight::One();
318 template <class Arc>
319 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
320            typename Arc::Weight weight_threshold,
321            typename Arc::StateId state_threshold = kNoStateId,
322            float delta = kDelta) {
323   const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
324       weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
325   Prune(ifst, ofst, opts);
326 }
327
328 }  // namespace fst
329
330 #endif  // FST_LIB_PRUNE_H_