Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / prune.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions implementing pruning.
5
6 #ifndef FST_PRUNE_H_
7 #define FST_PRUNE_H_
8
9 #include <type_traits>
10 #include <utility>
11 #include <vector>
12
13 #include <fst/log.h>
14
15 #include <fst/arcfilter.h>
16 #include <fst/heap.h>
17 #include <fst/shortest-distance.h>
18
19
20 namespace fst {
21 namespace internal {
22
23 template <class StateId, class Weight>
24 class PruneCompare {
25  public:
26   PruneCompare(const std::vector<Weight> &idistance,
27                const std::vector<Weight> &fdistance)
28       : idistance_(idistance), fdistance_(fdistance) {}
29
30   bool operator()(const StateId x, const StateId y) const {
31     const auto wx = Times(IDistance(x), FDistance(x));
32     const auto wy = Times(IDistance(y), FDistance(y));
33     return less_(wx, wy);
34   }
35
36  private:
37   Weight IDistance(const StateId s) const {
38     return s < idistance_.size() ? idistance_[s] : Weight::Zero();
39   }
40
41   Weight FDistance(const StateId s) const {
42     return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
43   }
44
45   const std::vector<Weight> &idistance_;
46   const std::vector<Weight> &fdistance_;
47   NaturalLess<Weight> less_;
48 };
49
50 }  // namespace internal
51
52 template <class Arc, class ArcFilter>
53 struct PruneOptions {
54   using StateId = typename Arc::StateId;
55   using Weight = typename Arc::Weight;
56
57   PruneOptions(const Weight &weight_threshold, StateId state_threshold,
58                ArcFilter filter, std::vector<Weight> *distance = nullptr,
59                float delta = kDelta, bool threshold_initial = false)
60       : weight_threshold(std::move(weight_threshold)),
61         state_threshold(state_threshold),
62         filter(std::move(filter)),
63         distance(distance),
64         delta(delta),
65         threshold_initial(threshold_initial) {}
66
67   // Pruning weight threshold.
68   Weight weight_threshold;
69   // Pruning state threshold.
70   StateId state_threshold;
71   // Arc filter.
72   ArcFilter filter;
73   // If non-zero, passes in pre-computed shortest distance to final states.
74   const std::vector<Weight> *distance;
75   // Determines the degree of convergence required when computing shortest
76   // distances.
77   float delta;
78   // Determines if the shortest path weight is left (true) or right
79   // (false) multiplied by the threshold to get the limit for
80   // keeping a state or arc (matters if the semiring is not
81   // commutative).
82   bool threshold_initial;
83 };
84
85 // Pruning algorithm: this version modifies its input and it takes an options
86 // class as an argument. After pruning the FST contains states and arcs that
87 // belong to a successful path in the FST whose weight is no more than the
88 // weight of the shortest path Times() the provided weight threshold. When the
89 // state threshold is not kNoStateId, the output FST is further restricted to
90 // have no more than the number of states in opts.state_threshold. Weights must
91 // have the path property. The weight of any cycle needs to be bounded; i.e.,
92 //
93 //   Plus(weight, Weight::One()) == Weight::One()
94 template <class Arc, class ArcFilter,
95           typename std::enable_if<
96               (Arc::Weight::Properties() & kPath) == kPath>::type * = nullptr>
97 void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &opts) {
98   using StateId = typename Arc::StateId;
99   using Weight = typename Arc::Weight;
100   using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
101   auto ns = fst->NumStates();
102   if (ns < 1) return;
103   std::vector<Weight> idistance(ns, Weight::Zero());
104   std::vector<Weight> tmp;
105   if (!opts.distance) {
106     tmp.reserve(ns);
107     ShortestDistance(*fst, &tmp, true, opts.delta);
108   }
109   const auto *fdistance = opts.distance ? opts.distance : &tmp;
110   if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) ||
111       ((*fdistance)[fst->Start()] == Weight::Zero())) {
112     fst->DeleteStates();
113     return;
114   }
115   internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
116   StateHeap heap(compare);
117   std::vector<bool> visited(ns, false);
118   std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
119   std::vector<StateId> dead;
120   dead.push_back(fst->AddState());
121   NaturalLess<Weight> less;
122   auto s = fst->Start();
123   const auto limit = opts.threshold_initial ?
124       Times(opts.weight_threshold, (*fdistance)[s]) :
125       Times((*fdistance)[s], opts.weight_threshold);
126   StateId num_visited = 0;
127
128   if (!less(limit, (*fdistance)[s])) {
129     idistance[s] = Weight::One();
130     enqueued[s] = heap.Insert(s);
131     ++num_visited;
132   }
133   while (!heap.Empty()) {
134     s = heap.Top();
135     heap.Pop();
136     enqueued[s] = StateHeap::kNoKey;
137     visited[s] = true;
138     if (less(limit, Times(idistance[s], fst->Final(s)))) {
139       fst->SetFinal(s, Weight::Zero());
140     }
141     for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
142          aiter.Next()) {
143       auto arc = aiter.Value();  // Copy intended.
144       if (!opts.filter(arc)) continue;
145       const auto weight = Times(Times(idistance[s], arc.weight),
146                                 arc.nextstate < fdistance->size() ?
147                                 (*fdistance)[arc.nextstate] : Weight::Zero());
148       if (less(limit, weight)) {
149         arc.nextstate = dead[0];
150         aiter.SetValue(arc);
151         continue;
152       }
153       if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
154         idistance[arc.nextstate] = Times(idistance[s], arc.weight);
155       }
156       if (visited[arc.nextstate]) continue;
157       if ((opts.state_threshold != kNoStateId) &&
158           (num_visited >= opts.state_threshold)) {
159         continue;
160       }
161       if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
162         enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
163         ++num_visited;
164       } else {
165         heap.Update(enqueued[arc.nextstate], arc.nextstate);
166       }
167     }
168   }
169   for (StateId i = 0; i < visited.size(); ++i) {
170     if (!visited[i]) dead.push_back(i);
171   }
172   fst->DeleteStates(dead);
173 }
174
175 template <class Arc, class ArcFilter,
176           typename std::enable_if<
177               (Arc::Weight::Properties() & kPath) != kPath>::type * = nullptr>
178 void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &) {
179   FSTERROR() << "Prune: Weight needs to have the path property: "
180              << Arc::Weight::Type();
181   fst->SetProperties(kError, kError);
182 }
183
184 // Pruning algorithm: this version modifies its input and takes the
185 // pruning threshold as an argument. It deletes states and arcs in the
186 // FST that do not belong to a successful path whose weight is more
187 // than the weight of the shortest path Times() the provided weight
188 // threshold. When the state threshold is not kNoStateId, the output
189 // FST is further restricted to have no more than the number of states
190 // in opts.state_threshold. Weights must have the path property. The
191 // weight of any cycle needs to be bounded; i.e.,
192 //
193 //   Plus(weight, Weight::One()) == Weight::One()
194 template <class Arc>
195 void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
196            typename Arc::StateId state_threshold = kNoStateId,
197            float delta = kDelta) {
198   const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
199       weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
200   Prune(fst, opts);
201 }
202
203 // Pruning algorithm: this version writes the pruned input FST to an
204 // output MutableFst and it takes an options class as an argument. The
205 // output FST contains states and arcs that belong to a successful
206 // path in the input FST whose weight is more than the weight of the
207 // shortest path Times() the provided weight threshold. When the state
208 // threshold is not kNoStateId, the output FST is further restricted
209 // to have no more than the number of states in
210 // opts.state_threshold. Weights have the path property.  The weight
211 // of any cycle needs to be bounded; i.e.,
212 //
213 //   Plus(weight, Weight::One()) == Weight::One()
214 template <class Arc, class ArcFilter,
215           typename std::enable_if<
216               (Arc::Weight::Properties() & kPath) == kPath>::type * = nullptr>
217 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
218            const PruneOptions<Arc, ArcFilter> &opts) {
219   using StateId = typename Arc::StateId;
220   using Weight = typename Arc::Weight;
221   using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
222   ofst->DeleteStates();
223   ofst->SetInputSymbols(ifst.InputSymbols());
224   ofst->SetOutputSymbols(ifst.OutputSymbols());
225   if (ifst.Start() == kNoStateId) return;
226   NaturalLess<Weight> less;
227   if (less(opts.weight_threshold, Weight::One()) ||
228       (opts.state_threshold == 0)) {
229     return;
230   }
231   std::vector<Weight> idistance;
232   std::vector<Weight> tmp;
233   if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta);
234   const auto *fdistance = opts.distance ? opts.distance : &tmp;
235   if ((fdistance->size() <= ifst.Start()) ||
236       ((*fdistance)[ifst.Start()] == Weight::Zero())) {
237     return;
238   }
239   internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
240   StateHeap heap(compare);
241   std::vector<StateId> copy;
242   std::vector<size_t> enqueued;
243   std::vector<bool> visited;
244   auto s = ifst.Start();
245   const auto limit = opts.threshold_initial ?
246       Times(opts.weight_threshold, (*fdistance)[s]) :
247       Times((*fdistance)[s], opts.weight_threshold);
248   while (copy.size() <= s) copy.push_back(kNoStateId);
249   copy[s] = ofst->AddState();
250   ofst->SetStart(copy[s]);
251   while (idistance.size() <= s) idistance.push_back(Weight::Zero());
252   idistance[s] = Weight::One();
253   while (enqueued.size() <= s) {
254     enqueued.push_back(StateHeap::kNoKey);
255     visited.push_back(false);
256   }
257   enqueued[s] = heap.Insert(s);
258   while (!heap.Empty()) {
259     s = heap.Top();
260     heap.Pop();
261     enqueued[s] = StateHeap::kNoKey;
262     visited[s] = true;
263     if (!less(limit, Times(idistance[s], ifst.Final(s)))) {
264       ofst->SetFinal(copy[s], ifst.Final(s));
265     }
266     for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
267       const auto &arc = aiter.Value();
268       if (!opts.filter(arc)) continue;
269       const auto weight = Times(Times(idistance[s], arc.weight),
270                                 arc.nextstate < fdistance->size() ?
271                                 (*fdistance)[arc.nextstate] : Weight::Zero());
272       if (less(limit, weight)) continue;
273       if ((opts.state_threshold != kNoStateId) &&
274           (ofst->NumStates() >= opts.state_threshold)) {
275         continue;
276       }
277       while (idistance.size() <= arc.nextstate) {
278         idistance.push_back(Weight::Zero());
279       }
280       if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
281         idistance[arc.nextstate] = Times(idistance[s], arc.weight);
282       }
283       while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId);
284       if (copy[arc.nextstate] == kNoStateId) {
285         copy[arc.nextstate] = ofst->AddState();
286       }
287       ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
288                                 copy[arc.nextstate]));
289       while (enqueued.size() <= arc.nextstate) {
290         enqueued.push_back(StateHeap::kNoKey);
291         visited.push_back(false);
292       }
293       if (visited[arc.nextstate]) continue;
294       if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
295         enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
296       } else {
297         heap.Update(enqueued[arc.nextstate], arc.nextstate);
298       }
299     }
300   }
301 }
302
303 template <class Arc, class ArcFilter,
304           typename std::enable_if<
305               (Arc::Weight::Properties() & kPath) != kPath>::type * = nullptr>
306 void Prune(const Fst<Arc> &, MutableFst<Arc> *ofst,
307            const PruneOptions<Arc, ArcFilter> &) {
308   FSTERROR() << "Prune: Weight needs to have the path property: "
309              << Arc::Weight::Type();
310   ofst->SetProperties(kError, kError);
311 }
312
313 // Pruning algorithm: this version writes the pruned input FST to an
314 // output MutableFst and simply takes the pruning threshold as an
315 // argument. The output FST contains states and arcs that belong to a
316 // successful path in the input FST whose weight is no more than the
317 // weight of the shortest path Times() the provided weight
318 // threshold. When the state threshold is not kNoStateId, the output
319 // FST is further restricted to have no more than the number of states
320 // in opts.state_threshold. Weights must have the path property. The
321 // weight of any cycle needs to be bounded; i.e.,
322 //
323 // Plus(weight, Weight::One()) = Weight::One();
324 template <class Arc>
325 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
326            typename Arc::Weight weight_threshold,
327            typename Arc::StateId state_threshold = kNoStateId,
328            float delta = kDelta) {
329   const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
330       weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
331   Prune(ifst, ofst, opts);
332 }
333
334 }  // namespace fst
335
336 #endif  // FST_PRUNE_H_