Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / include / fst / signed-log-weight.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // LogWeight along with sign information that represents the value X in the
5 // linear domain as <sign(X), -ln(|X|)>
6 //
7 // The sign is a TropicalWeight:
8 //  positive, TropicalWeight.Value() > 0.0, recommended value 1.0
9 //  negative, TropicalWeight.Value() <= 0.0, recommended value -1.0
10
11 #ifndef FST_SIGNED_LOG_WEIGHT_H_
12 #define FST_SIGNED_LOG_WEIGHT_H_
13
#include <climits>
#include <cmath>
#include <cstdlib>

#include <fst/float-weight.h>
#include <fst/pair-weight.h>
#include <fst/product-weight.h>
19
20
21 namespace fst {
22 template <class T>
23 class SignedLogWeightTpl : public PairWeight<TropicalWeight, LogWeightTpl<T>> {
24  public:
25   using X1 = TropicalWeight;
26   using X2 = LogWeightTpl<T>;
27   using ReverseWeight = SignedLogWeightTpl;
28
29   using PairWeight<X1, X2>::Value1;
30   using PairWeight<X1, X2>::Value2;
31
32   SignedLogWeightTpl() : PairWeight<X1, X2>() {}
33
34   SignedLogWeightTpl(const SignedLogWeightTpl &w) : PairWeight<X1, X2>(w) {}
35
36   explicit SignedLogWeightTpl(const PairWeight<X1, X2> &w)
37       : PairWeight<X1, X2>(w) {}
38
39   SignedLogWeightTpl(const X1 &x1, const X2 &x2) : PairWeight<X1, X2>(x1, x2) {}
40
41   static const SignedLogWeightTpl &Zero() {
42     static const SignedLogWeightTpl zero(X1(1.0), X2::Zero());
43     return zero;
44   }
45
46   static const SignedLogWeightTpl &One() {
47     static const SignedLogWeightTpl one(X1(1.0), X2::One());
48     return one;
49   }
50
51   static const SignedLogWeightTpl &NoWeight() {
52     static const SignedLogWeightTpl no_weight(X1(1.0), X2::NoWeight());
53     return no_weight;
54   }
55
56   static const string &Type() {
57     static const string *const type =
58         new string("signed_log_" + X1::Type() + "_" + X2::Type());
59     return *type;
60   }
61
62   SignedLogWeightTpl Quantize(float delta = kDelta) const {
63     return SignedLogWeightTpl(PairWeight<X1, X2>::Quantize(delta));
64   }
65
66   ReverseWeight Reverse() const {
67     return SignedLogWeightTpl(PairWeight<X1, X2>::Reverse());
68   }
69
70   bool Member() const { return PairWeight<X1, X2>::Member(); }
71
72   // Neither idempotent nor path.
73   static constexpr uint64 Properties() {
74     return kLeftSemiring | kRightSemiring | kCommutative;
75   }
76
77   size_t Hash() const {
78     size_t h1;
79     if (Value2() == X2::Zero() || Value1().Value() > 0.0) {
80       h1 = TropicalWeight(1.0).Hash();
81     } else {
82       h1 = TropicalWeight(-1.0).Hash();
83     }
84     size_t h2 = Value2().Hash();
85     static constexpr int lshift = 5;
86     static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5;
87     return h1 << lshift ^ h1 >> rshift ^ h2;
88   }
89 };
90
91 template <class T>
92 inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1,
93                                   const SignedLogWeightTpl<T> &w2) {
94   using X1 = TropicalWeight;
95   using X2 = LogWeightTpl<T>;
96   if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
97   const auto s1 = w1.Value1().Value() > 0.0;
98   const auto s2 = w2.Value1().Value() > 0.0;
99   const bool equal = (s1 == s2);
100   const auto f1 = w1.Value2().Value();
101   const auto f2 = w2.Value2().Value();
102   if (f1 == FloatLimits<T>::PosInfinity()) {
103     return w2;
104   } else if (f2 == FloatLimits<T>::PosInfinity()) {
105     return w1;
106   } else if (f1 == f2) {
107     if (equal) {
108       return SignedLogWeightTpl<T>(X1(w1.Value1()), X2(f2 - log(2.0F)));
109     } else {
110       return SignedLogWeightTpl<T>::Zero();
111     }
112   } else if (f1 > f2) {
113     if (equal) {
114       return SignedLogWeightTpl<T>(X1(w1.Value1()),
115                                    X2(f2 - internal::LogPosExp(f1 - f2)));
116     } else {
117       return SignedLogWeightTpl<T>(X1(w2.Value1()),
118                                    X2((f2 - internal::LogNegExp(f1 - f2))));
119     }
120   } else {
121     if (equal) {
122       return SignedLogWeightTpl<T>(X1(w2.Value1()),
123                                    X2((f1 - internal::LogPosExp(f2 - f1))));
124     } else {
125       return SignedLogWeightTpl<T>(X1(w1.Value1()),
126                                    X2((f1 - internal::LogNegExp(f2 - f1))));
127     }
128   }
129 }
130
131 template <class T>
132 inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1,
133                                    const SignedLogWeightTpl<T> &w2) {
134   SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2());
135   return Plus(w1, minus_w2);
136 }
137
138 template <class T>
139 inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1,
140                                    const SignedLogWeightTpl<T> &w2) {
141   using X2 = LogWeightTpl<T>;
142   if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
143   const auto s1 = w1.Value1().Value() > 0.0;
144   const auto s2 = w2.Value1().Value() > 0.0;
145   const auto f1 = w1.Value2().Value();
146   const auto f2 = w2.Value2().Value();
147   if (s1 == s2) {
148     return SignedLogWeightTpl<T>(TropicalWeight(1.0), X2(f1 + f2));
149   } else {
150     return SignedLogWeightTpl<T>(TropicalWeight(-1.0), X2(f1 + f2));
151   }
152 }
153
154 template <class T>
155 inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1,
156                                     const SignedLogWeightTpl<T> &w2,
157                                     DivideType typ = DIVIDE_ANY) {
158   using X2 = LogWeightTpl<T>;
159   if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
160   const auto s1 = w1.Value1().Value() > 0.0;
161   const auto s2 = w2.Value1().Value() > 0.0;
162   const auto f1 = w1.Value2().Value();
163   const auto f2 = w2.Value2().Value();
164   if (f2 == FloatLimits<T>::PosInfinity()) {
165     return SignedLogWeightTpl<T>(TropicalWeight(1.0),
166                                  X2(FloatLimits<T>::NumberBad()));
167   } else if (f1 == FloatLimits<T>::PosInfinity()) {
168     return SignedLogWeightTpl<T>(TropicalWeight(1.0),
169                                  X2(FloatLimits<T>::PosInfinity()));
170   } else if (s1 == s2) {
171     return SignedLogWeightTpl<T>(TropicalWeight(1.0), X2(f1 - f2));
172   } else {
173     return SignedLogWeightTpl<T>(TropicalWeight(-1.0), X2(f1 - f2));
174   }
175 }
176
177 template <class T>
178 inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
179                         const SignedLogWeightTpl<T> &w2, float delta = kDelta) {
180   const auto s1 = w1.Value1().Value() > 0.0;
181   const auto s2 = w2.Value1().Value() > 0.0;
182   if (s1 == s2) {
183     return ApproxEqual(w1.Value2(), w2.Value2(), delta);
184   } else {
185     return w1.Value2() == LogWeightTpl<T>::Zero() &&
186            w2.Value2() == LogWeightTpl<T>::Zero();
187   }
188 }
189
190 template <class T>
191 inline bool operator==(const SignedLogWeightTpl<T> &w1,
192                        const SignedLogWeightTpl<T> &w2) {
193   const auto s1 = w1.Value1().Value() > 0.0;
194   const auto s2 = w2.Value1().Value() > 0.0;
195   if (s1 == s2) {
196     return w1.Value2() == w2.Value2();
197   } else {
198     return (w1.Value2() == LogWeightTpl<T>::Zero()) &&
199            (w2.Value2() == LogWeightTpl<T>::Zero());
200   }
201 }
202
203 // Single-precision signed-log weight.
204 using SignedLogWeight = SignedLogWeightTpl<float>;
205
206 // Double-precision signed-log weight.
207 using SignedLog64Weight = SignedLogWeightTpl<double>;
208
209 template <class W1, class W2>
210 bool SignedLogConvertCheck(W1 weight) {
211   if (weight.Value1().Value() < 0.0) {
212     FSTERROR() << "WeightConvert: Can't convert weight " << weight
213                << " from " << W1::Type() << " to " << W2::Type();
214     return false;
215   }
216   return true;
217 }
218
219 // Specialization using the Kahan compensated summation
220 template <class T>
221 class Adder<SignedLogWeightTpl<T>> {
222  public:
223   using Weight = SignedLogWeightTpl<T>;
224   using X1 = TropicalWeight;
225   using X2 = LogWeightTpl<T>;
226
227   explicit Adder(Weight w = Weight::Zero())
228      : ssum_(w.Value1().Value() > 0.0),
229         sum_(w.Value2().Value()),
230         c_(0.0) { }
231
232   Weight Add(const Weight &w) {
233     const auto sw = w.Value1().Value() > 0.0;
234     const auto f = w.Value2().Value();
235     const bool equal = (ssum_ == sw);
236
237     if (!Sum().Member() || f == FloatLimits<T>::PosInfinity()) {
238       return Sum();
239     } else if (!w.Member() || sum_ == FloatLimits<T>::PosInfinity()) {
240       sum_ = f;
241       ssum_ = sw;
242       c_ = 0.0;
243     } else if (f == sum_) {
244       if (equal) {
245         sum_ = internal::KahanLogSum(sum_, f, &c_);
246       } else {
247         sum_ = FloatLimits<T>::PosInfinity();
248         ssum_ = true;
249         c_ = 0.0;
250       }
251     } else if (f > sum_) {
252       if (equal) {
253         sum_ = internal::KahanLogSum(sum_, f, &c_);
254       } else {
255         sum_ = internal::KahanLogDiff(sum_, f, &c_);
256       }
257     } else {
258       if (equal) {
259         sum_ = internal::KahanLogSum(f, sum_, &c_);
260       } else {
261         sum_ = internal::KahanLogDiff(f, sum_, &c_);
262         ssum_ = sw;
263       }
264     }
265     return Sum();
266   }
267
268   Weight Sum() { return Weight(X1(ssum_ ? 1.0 : -1.0), X2(sum_)); }
269
270   void Reset(Weight w = Weight::Zero()) {
271     ssum_ = w.Value1().Value() > 0.0;
272     sum_ = w.Value2().Value();
273     c_ = 0.0;
274   }
275
276  private:
277   bool ssum_;   // true iff sign of sum is positive
278   double sum_;  // unsigned sum
279   double c_;    // Kahan compensation
280 };
281
282 // Converts to tropical.
283 template <>
284 struct WeightConvert<SignedLogWeight, TropicalWeight> {
285   TropicalWeight operator()(const SignedLogWeight &weight) const {
286     if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(weight)) {
287       return TropicalWeight::NoWeight();
288     }
289     return TropicalWeight(weight.Value2().Value());
290   }
291 };
292
293 template <>
294 struct WeightConvert<SignedLog64Weight, TropicalWeight> {
295   TropicalWeight operator()(const SignedLog64Weight &weight) const {
296     if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(weight)) {
297       return TropicalWeight::NoWeight();
298     }
299     return TropicalWeight(weight.Value2().Value());
300   }
301 };
302
303 // Converts to log.
304 template <>
305 struct WeightConvert<SignedLogWeight, LogWeight> {
306   LogWeight operator()(const SignedLogWeight &weight) const {
307     if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(weight)) {
308       return LogWeight::NoWeight();
309     }
310     return LogWeight(weight.Value2().Value());
311   }
312 };
313
314 template <>
315 struct WeightConvert<SignedLog64Weight, LogWeight> {
316   LogWeight operator()(const SignedLog64Weight &weight) const {
317     if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(weight)) {
318       return LogWeight::NoWeight();
319     }
320     return LogWeight(weight.Value2().Value());
321   }
322 };
323
324 // Converts to log64.
325 template <>
326 struct WeightConvert<SignedLogWeight, Log64Weight> {
327   Log64Weight operator()(const SignedLogWeight &weight) const {
328     if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(weight)) {
329       return Log64Weight::NoWeight();
330     }
331     return Log64Weight(weight.Value2().Value());
332   }
333 };
334
335 template <>
336 struct WeightConvert<SignedLog64Weight, Log64Weight> {
337   Log64Weight operator()(const SignedLog64Weight &weight) const {
338     if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(weight)) {
339       return Log64Weight::NoWeight();
340     }
341     return Log64Weight(weight.Value2().Value());
342   }
343 };
344
345 // Converts to signed log.
346 template <>
347 struct WeightConvert<TropicalWeight, SignedLogWeight> {
348   SignedLogWeight operator()(const TropicalWeight &weight) const {
349     return SignedLogWeight(1.0, weight.Value());
350   }
351 };
352
353 template <>
354 struct WeightConvert<LogWeight, SignedLogWeight> {
355   SignedLogWeight operator()(const LogWeight &weight) const {
356     return SignedLogWeight(1.0, weight.Value());
357   }
358 };
359
360 template <>
361 struct WeightConvert<Log64Weight, SignedLogWeight> {
362   SignedLogWeight operator()(const Log64Weight &weight) const {
363     return SignedLogWeight(1.0, weight.Value());
364   }
365 };
366
367 template <>
368 struct WeightConvert<SignedLog64Weight, SignedLogWeight> {
369   SignedLogWeight operator()(const SignedLog64Weight &weight) const {
370     return SignedLogWeight(weight.Value1(), weight.Value2().Value());
371   }
372 };
373
374 // Converts to signed log64.
375 template <>
376 struct WeightConvert<TropicalWeight, SignedLog64Weight> {
377   SignedLog64Weight operator()(const TropicalWeight &weight) const {
378     return SignedLog64Weight(1.0, weight.Value());
379   }
380 };
381
382 template <>
383 struct WeightConvert<LogWeight, SignedLog64Weight> {
384   SignedLog64Weight operator()(const LogWeight &weight) const {
385     return SignedLog64Weight(1.0, weight.Value());
386   }
387 };
388
389 template <>
390 struct WeightConvert<Log64Weight, SignedLog64Weight> {
391   SignedLog64Weight operator()(const Log64Weight &weight) const {
392     return SignedLog64Weight(1.0, weight.Value());
393   }
394 };
395
396 template <>
397 struct WeightConvert<SignedLogWeight, SignedLog64Weight> {
398   SignedLog64Weight operator()(const SignedLogWeight &weight) const {
399     return SignedLog64Weight(weight.Value1(), weight.Value2().Value());
400   }
401 };
402
403 // This function object returns SignedLogWeightTpl<T>'s that are random integers
404 // chosen from [0, num_random_weights) times a random sign. This is intended
405 // primarily for testing.
406 template <class T>
407 class WeightGenerate<SignedLogWeightTpl<T>> {
408  public:
409   using Weight = SignedLogWeightTpl<T>;
410   using X1 = typename Weight::X1;
411   using X2 = typename Weight::X2;
412
413   explicit WeightGenerate(bool allow_zero = true,
414                           size_t num_random_weights = kNumRandomWeights)
415     : allow_zero_(allow_zero), num_random_weights_(num_random_weights) {}
416
417   Weight operator()() const {
418     static const X1 negative_one(-1.0);
419     static const X1 positive_one(+1.0);
420     const int m = rand() % 2;                                    // NOLINT
421     const int n = rand() % (num_random_weights_ + allow_zero_);  // NOLINT
422     return Weight((m == 0) ? negative_one : positive_one,
423                   (allow_zero_ && n == num_random_weights_) ?
424                    X2::Zero() : X2(n));
425   }
426
427  private:
428   // Permits Zero() and zero divisors.
429   const bool allow_zero_;
430   // Number of alternative random weights.
431   const size_t num_random_weights_;
432 };
433
434 }  // namespace fst
435
436 #endif  // FST_SIGNED_LOG_WEIGHT_H_