1a9af0a01b741b7f4bb43490843ddf3eaf3f63f8
[platform/upstream/openfst.git] / src / include / fst / reweight.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Function to reweight an FST.
5
6 #ifndef FST_LIB_REWEIGHT_H_
7 #define FST_LIB_REWEIGHT_H_
8
9 #include <vector>
10 #include <fst/log.h>
11
12 #include <fst/mutable-fst.h>
13
14
15 namespace fst {
16
17 enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL };
18
19 // Reweights an FST according to a vector of potentials in a given direction.
20 // The weight must be left distributive when reweighting towards the initial
21 // state and right distributive when reweighting towards the final states.
22 //
23 // An arc of weight w, with an origin state of potential p and destination state
24 // of potential q, is reweighted by p^-1 \otimes (w \otimes q) when reweighting
25 // torwards the initial state, and by (p \otimes w) \otimes q^-1 when
26 // reweighting towards the final states.
27 template <class Arc>
28 void Reweight(MutableFst<Arc> *fst,
29               const std::vector<typename Arc::Weight> &potential,
30               ReweightType type) {
31   using Weight = typename Arc::Weight;
32   if (fst->NumStates() == 0) return;
33   // TODO(kbg): Make this a compile-time static_assert once:
34   // 1) All weight properties are made constexpr for all weight types.
35   // 2) We have a pleasant way to "deregister" this operation for non-path
36   //    semirings so an informative error message is produced. The best
37   //    solution will probably involve some kind of SFINAE magic.
38   if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) {
39     FSTERROR() << "Reweight: Reweighting to the final states requires "
40                << "Weight to be right distributive: " << Weight::Type();
41     fst->SetProperties(kError, kError);
42     return;
43   }
44   // TODO(kbg): Make this a compile-time static_assert once:
45   // 1) All weight properties are made constexpr for all weight types.
46   // 2) We have a pleasant way to "deregister" this operation for non-path
47   //    semirings so an informative error message is produced. The best
48   //    solution will probably involve some kind of SFINAE magic.
49   if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) {
50     FSTERROR() << "Reweight: Reweighting to the initial state requires "
51                << "Weight to be left distributive: " << Weight::Type();
52     fst->SetProperties(kError, kError);
53     return;
54   }
55   StateIterator<MutableFst<Arc>> siter(*fst);
56   for (; !siter.Done(); siter.Next()) {
57     const auto s = siter.Value();
58     if (s == potential.size()) break;
59     const auto &weight = potential[s];
60     if (weight != Weight::Zero()) {
61       for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
62            aiter.Next()) {
63         auto arc = aiter.Value();
64         if (arc.nextstate >= potential.size()) continue;
65         const auto &nextweight = potential[arc.nextstate];
66         if (nextweight == Weight::Zero()) continue;
67         if (type == REWEIGHT_TO_INITIAL) {
68           arc.weight =
69               Divide(Times(arc.weight, nextweight), weight, DIVIDE_LEFT);
70         }
71         if (type == REWEIGHT_TO_FINAL) {
72           arc.weight =
73               Divide(Times(weight, arc.weight), nextweight, DIVIDE_RIGHT);
74         }
75         aiter.SetValue(arc);
76       }
77       if (type == REWEIGHT_TO_INITIAL) {
78         fst->SetFinal(s, Divide(fst->Final(s), weight, DIVIDE_LEFT));
79       }
80     }
81     if (type == REWEIGHT_TO_FINAL) {
82       fst->SetFinal(s, Times(weight, fst->Final(s)));
83     }
84   }
85   // This handles elements past the end of the potentials array.
86   for (; !siter.Done(); siter.Next()) {
87     const auto s = siter.Value();
88     if (type == REWEIGHT_TO_FINAL) {
89       fst->SetFinal(s, Times(Weight::Zero(), fst->Final(s)));
90     }
91   }
92   const auto startweight = fst->Start() < potential.size()
93                                ? potential[fst->Start()]
94                                : Weight::Zero();
95   if ((startweight != Weight::One()) && (startweight != Weight::Zero())) {
96     if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
97       const auto s = fst->Start();
98       for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
99            aiter.Next()) {
100         auto arc = aiter.Value();
101         if (type == REWEIGHT_TO_INITIAL) {
102           arc.weight = Times(startweight, arc.weight);
103         } else {
104           arc.weight = Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
105                              arc.weight);
106         }
107         aiter.SetValue(arc);
108       }
109       if (type == REWEIGHT_TO_INITIAL) {
110         fst->SetFinal(s, Times(startweight, fst->Final(s)));
111       } else {
112         fst->SetFinal(s, Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
113                                fst->Final(s)));
114       }
115     } else {
116       const auto s = fst->AddState();
117       const auto weight =
118           (type == REWEIGHT_TO_INITIAL)
119               ? startweight
120               : Divide(Weight::One(), startweight, DIVIDE_RIGHT);
121       fst->AddArc(s, Arc(0, 0, weight, fst->Start()));
122       fst->SetStart(s);
123     }
124   }
125   fst->SetProperties(ReweightProperties(fst->Properties(kFstProperties, false)),
126                      kFstProperties);
127 }
128
129 }  // namespace fst
130
131 #endif  // FST_LIB_REWEIGHT_H_