1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Function to reweight an FST.
6 #ifndef FST_LIB_REWEIGHT_H_
7 #define FST_LIB_REWEIGHT_H_
12 #include <fst/mutable-fst.h>
// Selects the direction in which Reweight() moves weight: toward the initial
// state (REWEIGHT_TO_INITIAL) or toward the final states (REWEIGHT_TO_FINAL).
enum ReweightType {
  REWEIGHT_TO_INITIAL = 0,
  REWEIGHT_TO_FINAL = 1,
};
19 // Reweights an FST according to a vector of potentials in a given direction.
20 // The weight must be left distributive when reweighting towards the initial
21 // state and right distributive when reweighting towards the final states.
23 // An arc of weight w, with an origin state of potential p and destination state
24 // of potential q, is reweighted by p^-1 \otimes (w \otimes q) when reweighting
25 // towards the initial state, and by (p \otimes w) \otimes q^-1 when
26 // reweighting towards the final states.
//
// Parameters:
//   fst: the FST to reweight; it is modified in place (SetFinal, AddState,
//     AddArc and SetProperties are called on it).
//   potential: per-state potentials, indexed by state id. States whose id is
//     not less than potential.size() have no potential; they are handled by
//     the second state loop below.
//   type: REWEIGHT_TO_INITIAL or REWEIGHT_TO_FINAL. (The parameter
//     declaration itself is not visible in this view.)
//
// An FST with no states is left unchanged. If Weight lacks the distributivity
// property the chosen direction requires, an error is logged via FSTERROR()
// and kError is set on the FST.
//
// NOTE(review): several lines of this definition are not visible in this
// view: the template header, closing braces, early returns, and continuation
// lines. The comments below describe only the visible statements.
28 void Reweight(MutableFst<Arc> *fst,
29 const std::vector<typename Arc::Weight> &potential,
31 using Weight = typename Arc::Weight;
// Nothing to reweight in an FST with no states.
32 if (fst->NumStates() == 0) return;
33 // TODO(kbg): Make this a compile-time static_assert once:
34 // 1) All weight properties are made constexpr for all weight types.
35 // 2) We have a pleasant way to "deregister" this operation for non-path
36 // semirings so an informative error message is produced. The best
37 // solution will probably involve some kind of SFINAE magic.
// Reweighting toward the final states uses DIVIDE_RIGHT below, so Weight must
// be right distributive (kRightSemiring). Otherwise the error is reported and
// kError is set; an early return presumably follows on a line not visible here.
38 if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) {
39 FSTERROR() << "Reweight: Reweighting to the final states requires "
40 << "Weight to be right distributive: " << Weight::Type();
41 fst->SetProperties(kError, kError);
44 // TODO(kbg): Make this a compile-time static_assert once:
45 // 1) All weight properties are made constexpr for all weight types.
46 // 2) We have a pleasant way to "deregister" this operation for non-path
47 // semirings so an informative error message is produced. The best
48 // solution will probably involve some kind of SFINAE magic.
// Symmetric check: reweighting toward the initial state uses DIVIDE_LEFT, so
// Weight must be left distributive (kLeftSemiring).
49 if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) {
50 FSTERROR() << "Reweight: Reweighting to the initial state requires "
51 << "Weight to be left distributive: " << Weight::Type();
52 fst->SetProperties(kError, kError);
// The iterator is declared outside the loop so that the second loop further
// down resumes from the first state whose id reaches potential.size().
55 StateIterator<MutableFst<Arc>> siter(*fst);
56 for (; !siter.Done(); siter.Next()) {
57 const auto s = siter.Value();
58 if (s == potential.size()) break;
59 const auto &weight = potential[s];
// Arcs leaving a state whose potential is Zero are not reweighted.
60 if (weight != Weight::Zero()) {
61 for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
63 auto arc = aiter.Value();
// Arcs into a state with no potential, or whose destination potential is
// Zero, are skipped.
64 if (arc.nextstate >= potential.size()) continue;
65 const auto &nextweight = potential[arc.nextstate];
66 if (nextweight == Weight::Zero()) continue;
// The new arc weight follows the formulas in the header comment:
// p^-1 (w q) toward the initial state, (p w) q^-1 toward the final states.
// The assignment targets and the write-back of `arc` are on lines not
// visible here.
67 if (type == REWEIGHT_TO_INITIAL) {
69 Divide(Times(arc.weight, nextweight), weight, DIVIDE_LEFT);
71 if (type == REWEIGHT_TO_FINAL) {
73 Divide(Times(weight, arc.weight), nextweight, DIVIDE_RIGHT);
// Final weights: divided on the left by p toward the initial state and
// multiplied on the left by p toward the final states.
77 if (type == REWEIGHT_TO_INITIAL) {
78 fst->SetFinal(s, Divide(fst->Final(s), weight, DIVIDE_LEFT));
81 if (type == REWEIGHT_TO_FINAL) {
82 fst->SetFinal(s, Times(weight, fst->Final(s)));
85 // This handles states whose ids are past the end of the potentials array:
// toward the final states their final weight becomes Zero (x) Final(s), which
// is Zero in a semiring, since Zero annihilates under Times.
86 for (; !siter.Done(); siter.Next()) {
87 const auto s = siter.Value();
88 if (type == REWEIGHT_TO_FINAL) {
89 fst->SetFinal(s, Times(Weight::Zero(), fst->Final(s)));
// Potential of the start state. The fallback value used when the start state
// has no potential is on a line not visible here.
92 const auto startweight = fst->Start() < potential.size()
93 ? potential[fst->Start()]
// A start potential of One or Zero leaves the start of the FST untouched.
95 if ((startweight != Weight::One()) && (startweight != Weight::Zero())) {
// If no arc re-enters the start state, the start potential is folded directly
// into the start state's outgoing arcs and final weight. Presumably a
// re-entrant start state would also rescale cyclic paths, which is why the
// other branch adds a fresh state instead.
96 if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
97 const auto s = fst->Start();
98 for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
100 auto arc = aiter.Value();
101 if (type == REWEIGHT_TO_INITIAL) {
102 arc.weight = Times(startweight, arc.weight);
104 arc.weight = Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
109 if (type == REWEIGHT_TO_INITIAL) {
110 fst->SetFinal(s, Times(startweight, fst->Final(s)));
112 fst->SetFinal(s, Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
// Otherwise (the enclosing else is not visible here), a fresh state is added.
// An arc with labels 0 (presumably epsilon) leads from it into the old start
// state, carrying the start potential or, toward the final states, One divided
// on the right by it. The call that presumably makes this state the new start
// is not visible here.
116 const auto s = fst->AddState();
118 (type == REWEIGHT_TO_INITIAL)
120 : Divide(Weight::One(), startweight, DIVIDE_RIGHT);
121 fst->AddArc(s, Arc(0, 0, weight, fst->Start()));
// Recompute the stored property bits from the current ones via
// ReweightProperties. The mask argument is on a line not visible here.
125 fst->SetProperties(ReweightProperties(fst->Properties(kFstProperties, false)),
131 #endif // FST_LIB_REWEIGHT_H_