49daa03cebe6015d0ec5e6c1f6a53f69d3a0ff24
[platform/upstream/openfst.git] / src / include / fst / prune.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions implementing pruning.
5
6 #ifndef FST_LIB_PRUNE_H_
7 #define FST_LIB_PRUNE_H_
8
9 #include <utility>
10 #include <vector>
11
12 #include <fst/log.h>
13
14 #include <fst/arcfilter.h>
15 #include <fst/heap.h>
16 #include <fst/shortest-distance.h>
17
18
19 namespace fst {
20 namespace internal {
21
22 template <class StateId, class Weight>
23 class PruneCompare {
24  public:
25   PruneCompare(const std::vector<Weight> &idistance,
26                const std::vector<Weight> &fdistance)
27       : idistance_(idistance), fdistance_(fdistance) {}
28
29   bool operator()(const StateId x, const StateId y) const {
30     const auto wx = Times(IDistance(x), FDistance(x));
31     const auto wy = Times(IDistance(y), FDistance(y));
32     return less_(wx, wy);
33   }
34
35  private:
36   Weight IDistance(const StateId s) const {
37     return s < idistance_.size() ? idistance_[s] : Weight::Zero();
38   }
39
40   Weight FDistance(const StateId s) const {
41     return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
42   }
43
44   const std::vector<Weight> &idistance_;
45   const std::vector<Weight> &fdistance_;
46   NaturalLess<Weight> less_;
47 };
48
49 }  // namespace internal
50
51 template <class Arc, class ArcFilter>
52 struct PruneOptions {
53   using StateId = typename Arc::StateId;
54   using Weight = typename Arc::Weight;
55
56   PruneOptions(const Weight &weight_threshold, StateId state_threshold,
57                ArcFilter filter, std::vector<Weight> *distance = nullptr,
58                float delta = kDelta, bool threshold_initial = false)
59       : weight_threshold(std::move(weight_threshold)),
60         state_threshold(state_threshold),
61         filter(std::move(filter)),
62         distance(distance),
63         delta(delta),
64         threshold_initial(threshold_initial) {}
65
66   // Pruning weight threshold.
67   Weight weight_threshold;
68   // Pruning state threshold.
69   StateId state_threshold;
70   // Arc filter.
71   ArcFilter filter;
72   // If non-zero, passes in pre-computed shortest distance to final states.
73   const std::vector<Weight> *distance;
74   // Determines the degree of convergence required when computing shortest
75   // distances.
76   float delta;
77   // Determines if the shortest path weight is left (true) or right
78   // (false) multiplied by the threshold to get the limit for
79   // keeping a state or arc (matters if the semiring is not
80   // commutative).
81   bool threshold_initial;
82 };
83
84 // Pruning algorithm: this version modifies its input and it takes an options
85 // class as an argument. After pruning the FST contains states and arcs that
86 // belong to a successful path in the FST whose weight is no more than the
87 // weight of the shortest path Times() the provided weight threshold. When the
88 // state threshold is not kNoStateId, the output FST is further restricted to
89 // have no more than the number of states in opts.state_threshold. Weights must
90 // have the path property. The weight of any cycle needs to be bounded; i.e.,
91 //
92 //   Plus(weight, Weight::One()) == Weight::One()
93 template <class Arc, class ArcFilter>
94 void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &opts) {
95   using StateId = typename Arc::StateId;
96   using Weight = typename Arc::Weight;
97   using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
98   // TODO(kbg): Make this a compile-time static_assert once:
99   // 1) All weight properties are made constexpr for all weight types.
100   // 2) We have a pleasant way to "deregister" this operation for non-path
101   //    semirings so an informative error message is produced. The best
102   //    solution will probably involve some kind of SFINAE magic.
103   if ((Weight::Properties() & kPath) != kPath) {
104     FSTERROR() << "Prune: Weight needs to have the path property: "
105                << Weight::Type();
106     fst->SetProperties(kError, kError);
107     return;
108   }
109   auto ns = fst->NumStates();
110   if (ns < 1) return;
111   std::vector<Weight> idistance(ns, Weight::Zero());
112   std::vector<Weight> tmp;
113   if (!opts.distance) {
114     tmp.reserve(ns);
115     ShortestDistance(*fst, &tmp, true, opts.delta);
116   }
117   const auto *fdistance = opts.distance ? opts.distance : &tmp;
118   if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) ||
119       ((*fdistance)[fst->Start()] == Weight::Zero())) {
120     fst->DeleteStates();
121     return;
122   }
123   internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
124   StateHeap heap(compare);
125   std::vector<bool> visited(ns, false);
126   std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
127   std::vector<StateId> dead;
128   dead.push_back(fst->AddState());
129   NaturalLess<Weight> less;
130   auto s = fst->Start();
131   const auto limit = opts.threshold_initial ?
132       Times(opts.weight_threshold, (*fdistance)[s]) :
133       Times((*fdistance)[s], opts.weight_threshold);
134   StateId num_visited = 0;
135
136   if (!less(limit, (*fdistance)[s])) {
137     idistance[s] = Weight::One();
138     enqueued[s] = heap.Insert(s);
139     ++num_visited;
140   }
141   while (!heap.Empty()) {
142     s = heap.Top();
143     heap.Pop();
144     enqueued[s] = StateHeap::kNoKey;
145     visited[s] = true;
146     if (less(limit, Times(idistance[s], fst->Final(s)))) {
147       fst->SetFinal(s, Weight::Zero());
148     }
149     for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
150          aiter.Next()) {
151       auto arc = aiter.Value();  // Copy intended.
152       if (!opts.filter(arc)) continue;
153       const auto weight = Times(Times(idistance[s], arc.weight),
154                                 arc.nextstate < fdistance->size() ?
155                                 (*fdistance)[arc.nextstate] : Weight::Zero());
156       if (less(limit, weight)) {
157         arc.nextstate = dead[0];
158         aiter.SetValue(arc);
159         continue;
160       }
161       if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
162         idistance[arc.nextstate] = Times(idistance[s], arc.weight);
163       }
164       if (visited[arc.nextstate]) continue;
165       if ((opts.state_threshold != kNoStateId) &&
166           (num_visited >= opts.state_threshold)) {
167         continue;
168       }
169       if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
170         enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
171         ++num_visited;
172       } else {
173         heap.Update(enqueued[arc.nextstate], arc.nextstate);
174       }
175     }
176   }
177   for (StateId i = 0; i < visited.size(); ++i) {
178     if (!visited[i]) dead.push_back(i);
179   }
180   fst->DeleteStates(dead);
181 }
182
183 // Pruning algorithm: this version modifies its input and takes the
184 // pruning threshold as an argument. It deletes states and arcs in the
185 // FST that do not belong to a successful path whose weight is more
186 // than the weight of the shortest path Times() the provided weight
187 // threshold. When the state threshold is not kNoStateId, the output
188 // FST is further restricted to have no more than the number of states
189 // in opts.state_threshold. Weights must have the path property. The
190 // weight of any cycle needs to be bounded; i.e.,
191 //
192 //   Plus(weight, Weight::One()) == Weight::One()
193 template <class Arc>
194 void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
195            typename Arc::StateId state_threshold = kNoStateId,
196            double delta = kDelta) {
197   const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
198       weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
199   Prune(fst, opts);
200 }
201
202 // Pruning algorithm: this version writes the pruned input FST to an
203 // output MutableFst and it takes an options class as an argument. The
204 // output FST contains states and arcs that belong to a successful
205 // path in the input FST whose weight is more than the weight of the
206 // shortest path Times() the provided weight threshold. When the state
207 // threshold is not kNoStateId, the output FST is further restricted
208 // to have no more than the number of states in
209 // opts.state_threshold. Weights have the path property.  The weight
210 // of any cycle needs to be bounded; i.e.,
211 //
212 //   Plus(weight, Weight::One()) == Weight::One()
213 template <class Arc, class ArcFilter>
214 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
215            const PruneOptions<Arc, ArcFilter> &opts) {
216   using StateId = typename Arc::StateId;
217   using Weight = typename Arc::Weight;
218   using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
219   // TODO(kbg): Make this a compile-time static_assert once:
220   // 1) All weight properties are made constexpr for all weight types.
221   // 2) We have a pleasant way to "deregister" this operation for non-path
222   //    semirings so an informative error message is produced. The best
223   //    solution will probably involve some kind of SFINAE magic.
224   if ((Weight::Properties() & kPath) != kPath) {
225     FSTERROR() << "Prune: Weight needs to have the path property: "
226                << Weight::Type();
227     ofst->SetProperties(kError, kError);
228     return;
229   }
230   ofst->DeleteStates();
231   ofst->SetInputSymbols(ifst.InputSymbols());
232   ofst->SetOutputSymbols(ifst.OutputSymbols());
233   if (ifst.Start() == kNoStateId) return;
234   NaturalLess<Weight> less;
235   if (less(opts.weight_threshold, Weight::One()) ||
236       (opts.state_threshold == 0)) {
237     return;
238   }
239   std::vector<Weight> idistance;
240   std::vector<Weight> tmp;
241   if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta);
242   const auto *fdistance = opts.distance ? opts.distance : &tmp;
243   if ((fdistance->size() <= ifst.Start()) ||
244       ((*fdistance)[ifst.Start()] == Weight::Zero())) {
245     return;
246   }
247   internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
248   StateHeap heap(compare);
249   std::vector<StateId> copy;
250   std::vector<size_t> enqueued;
251   std::vector<bool> visited;
252   auto s = ifst.Start();
253   const auto limit = opts.threshold_initial ?
254       Times(opts.weight_threshold, (*fdistance)[s]) :
255       Times((*fdistance)[s], opts.weight_threshold);
256   while (copy.size() <= s) copy.push_back(kNoStateId);
257   copy[s] = ofst->AddState();
258   ofst->SetStart(copy[s]);
259   while (idistance.size() <= s) idistance.push_back(Weight::Zero());
260   idistance[s] = Weight::One();
261   while (enqueued.size() <= s) {
262     enqueued.push_back(StateHeap::kNoKey);
263     visited.push_back(false);
264   }
265   enqueued[s] = heap.Insert(s);
266   while (!heap.Empty()) {
267     s = heap.Top();
268     heap.Pop();
269     enqueued[s] = StateHeap::kNoKey;
270     visited[s] = true;
271     if (!less(limit, Times(idistance[s], ifst.Final(s)))) {
272       ofst->SetFinal(copy[s], ifst.Final(s));
273     }
274     for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
275       const auto &arc = aiter.Value();
276       if (!opts.filter(arc)) continue;
277       const auto weight = Times(Times(idistance[s], arc.weight),
278                                 arc.nextstate < fdistance->size() ?
279                                 (*fdistance)[arc.nextstate] : Weight::Zero());
280       if (less(limit, weight)) continue;
281       if ((opts.state_threshold != kNoStateId) &&
282           (ofst->NumStates() >= opts.state_threshold)) {
283         continue;
284       }
285       while (idistance.size() <= arc.nextstate) {
286         idistance.push_back(Weight::Zero());
287       }
288       if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
289         idistance[arc.nextstate] = Times(idistance[s], arc.weight);
290       }
291       while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId);
292       if (copy[arc.nextstate] == kNoStateId) {
293         copy[arc.nextstate] = ofst->AddState();
294       }
295       ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
296                                 copy[arc.nextstate]));
297       while (enqueued.size() <= arc.nextstate) {
298         enqueued.push_back(StateHeap::kNoKey);
299         visited.push_back(false);
300       }
301       if (visited[arc.nextstate]) continue;
302       if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
303         enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
304       } else {
305         heap.Update(enqueued[arc.nextstate], arc.nextstate);
306       }
307     }
308   }
309 }
310
311 // Pruning algorithm: this version writes the pruned input FST to an
312 // output MutableFst and simply takes the pruning threshold as an
313 // argument. The output FST contains states and arcs that belong to a
314 // successful path in the input FST whose weight is no more than the
315 // weight of the shortest path Times() the provided weight
316 // threshold. When the state threshold is not kNoStateId, the output
317 // FST is further restricted to have no more than the number of states
318 // in opts.state_threshold. Weights must have the path property. The
319 // weight of any cycle needs to be bounded; i.e.,
320 //
321 // Plus(weight, Weight::One()) = Weight::One();
322 template <class Arc>
323 void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
324            typename Arc::Weight weight_threshold,
325            typename Arc::StateId state_threshold = kNoStateId,
326            float delta = kDelta) {
327   const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
328       weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
329   Prune(ifst, ofst, opts);
330 }
331
332 }  // namespace fst
333
334 #endif  // FST_LIB_PRUNE_H_