965b3c9669b79242a294c25796ad7a34203b44c6
[platform/upstream/openfst.git] / src / include / fst / float-weight.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Float weight set and associated semiring operation definitions.
5
6 #ifndef FST_LIB_FLOAT_WEIGHT_H_
7 #define FST_LIB_FLOAT_WEIGHT_H_
8
9 #include <climits>
10 #include <cmath>
11 #include <cstdlib>
12
13 #include <limits>
14 #include <sstream>
15 #include <string>
16
17 #include <fst/util.h>
18 #include <fst/weight.h>
19
20
21 namespace fst {
22
23 // Numeric limits class.
24 template <class T>
25 class FloatLimits {
26  public:
27   static constexpr T PosInfinity() {
28     return std::numeric_limits<T>::infinity();
29   }
30
31   static constexpr T NegInfinity() { return -PosInfinity(); }
32
33   static constexpr T NumberBad() { return std::numeric_limits<T>::quiet_NaN(); }
34 };
35
36 // Weight class to be templated on floating-points types.
37 template <class T = float>
38 class FloatWeightTpl {
39  public:
40   using ValueType = T;
41
42   FloatWeightTpl() {}
43
44   FloatWeightTpl(T f) : value_(f) {}
45
46   FloatWeightTpl(const FloatWeightTpl<T> &weight) : value_(weight.value_) {}
47
48   FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &weight) {
49     value_ = weight.value_;
50     return *this;
51   }
52
53   std::istream &Read(std::istream &strm) { return ReadType(strm, &value_); }
54
55   std::ostream &Write(std::ostream &strm) const {
56     return WriteType(strm, value_);
57   }
58
59   size_t Hash() const {
60     union {
61       T f;
62       size_t s;
63     } u;
64     u.s = 0;
65     u.f = value_;
66     return u.s;
67   }
68
69   const T &Value() const { return value_; }
70
71  protected:
72   void SetValue(const T &f) { value_ = f; }
73
74   static constexpr const char *GetPrecisionString() {
75     return sizeof(T) == 4
76                ? ""
77                : sizeof(T) == 1
78                      ? "8"
79                      : sizeof(T) == 2 ? "16"
80                                       : sizeof(T) == 8 ? "64" : "unknown";
81   }
82
83  private:
84   T value_;
85 };
86
87 // Single-precision float weight.
88 using FloatWeight = FloatWeightTpl<float>;
89
90 template <class T>
91 inline bool operator==(const FloatWeightTpl<T> &w1,
92                        const FloatWeightTpl<T> &w2) {
93   // Volatile qualifier thwarts over-aggressive compiler optimizations that
94   // lead to problems esp. with NaturalLess().
95   volatile T v1 = w1.Value();
96   volatile T v2 = w2.Value();
97   return v1 == v2;
98 }
99
100 inline bool operator==(const FloatWeightTpl<double> &w1,
101                        const FloatWeightTpl<double> &w2) {
102   return operator==<double>(w1, w2);
103 }
104
105 inline bool operator==(const FloatWeightTpl<float> &w1,
106                        const FloatWeightTpl<float> &w2) {
107   return operator==<float>(w1, w2);
108 }
109
110 template <class T>
111 inline bool operator!=(const FloatWeightTpl<T> &w1,
112                        const FloatWeightTpl<T> &w2) {
113   return !(w1 == w2);
114 }
115
116 inline bool operator!=(const FloatWeightTpl<double> &w1,
117                        const FloatWeightTpl<double> &w2) {
118   return operator!=<double>(w1, w2);
119 }
120
121 inline bool operator!=(const FloatWeightTpl<float> &w1,
122                        const FloatWeightTpl<float> &w2) {
123   return operator!=<float>(w1, w2);
124 }
125
126 template <class T>
127 inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
128                         const FloatWeightTpl<T> &w2, float delta = kDelta) {
129   return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
130 }
131
132 template <class T>
133 inline std::ostream &operator<<(std::ostream &strm,
134                                 const FloatWeightTpl<T> &w) {
135   if (w.Value() == FloatLimits<T>::PosInfinity()) {
136     return strm << "Infinity";
137   } else if (w.Value() == FloatLimits<T>::NegInfinity()) {
138     return strm << "-Infinity";
139   } else if (w.Value() != w.Value()) {  // Fails for IEEE NaN.
140     return strm << "BadNumber";
141   } else {
142     return strm << w.Value();
143   }
144 }
145
146 template <class T>
147 inline std::istream &operator>>(std::istream &strm, FloatWeightTpl<T> &w) {
148   string s;
149   strm >> s;
150   if (s == "Infinity") {
151     w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity());
152   } else if (s == "-Infinity") {
153     w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity());
154   } else {
155     char *p;
156     T f = strtod(s.c_str(), &p);
157     if (p < s.c_str() + s.size()) {
158       strm.clear(std::ios::badbit);
159     } else {
160       w = FloatWeightTpl<T>(f);
161     }
162   }
163   return strm;
164 }
165
166 // Tropical semiring: (min, +, inf, 0).
167 template <class T>
168 class TropicalWeightTpl : public FloatWeightTpl<T> {
169  public:
170   using typename FloatWeightTpl<T>::ValueType;
171   using FloatWeightTpl<T>::Value;
172   using ReverseWeight = TropicalWeightTpl<T>;
173   using Limits = FloatLimits<T>;
174
175   constexpr TropicalWeightTpl() : FloatWeightTpl<T>() {}
176
177   constexpr TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
178
179   constexpr TropicalWeightTpl(const TropicalWeightTpl<T> &weight)
180       : FloatWeightTpl<T>(weight) {}
181
182   static const TropicalWeightTpl<T> &Zero() {
183     static const TropicalWeightTpl zero(Limits::PosInfinity());
184     return zero;
185   }
186
187   static const TropicalWeightTpl<T> &One() {
188     static const TropicalWeightTpl one(0.0F);
189     return one;
190   }
191
192   static const TropicalWeightTpl<T> &NoWeight() {
193     static const TropicalWeightTpl no_weight(Limits::NumberBad());
194     return no_weight;
195   }
196
197   static const string &Type() {
198     static const string *const type =
199         new string(string("tropical") +
200                    FloatWeightTpl<T>::GetPrecisionString());
201     return *type;
202   }
203
204   bool Member() const {
205     // First part fails for IEEE NaN.
206     return Value() == Value() && Value() != Limits::NegInfinity();
207   }
208
209   TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
210     if (!Member() || Value() == Limits::PosInfinity()) {
211       return *this;
212     } else {
213       return TropicalWeightTpl<T>(floor(Value() / delta + 0.5F) * delta);
214     }
215   }
216
217   TropicalWeightTpl<T> Reverse() const { return *this; }
218
219   static constexpr uint64 Properties() {
220     return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
221   }
222 };
223
224 // Single precision tropical weight.
225 using TropicalWeight = TropicalWeightTpl<float>;
226
227 template <class T>
228 inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
229                                  const TropicalWeightTpl<T> &w2) {
230   if (!w1.Member() || !w2.Member()) return TropicalWeightTpl<T>::NoWeight();
231   return w1.Value() < w2.Value() ? w1 : w2;
232 }
233
234 inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
235                                      const TropicalWeightTpl<float> &w2) {
236   return Plus<float>(w1, w2);
237 }
238
239 inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
240                                       const TropicalWeightTpl<double> &w2) {
241   return Plus<double>(w1, w2);
242 }
243
244 template <class T>
245 inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
246                                   const TropicalWeightTpl<T> &w2) {
247   using Limits = FloatLimits<T>;
248   if (!w1.Member() || !w2.Member()) return TropicalWeightTpl<T>::NoWeight();
249   const T f1 = w1.Value();
250   const T f2 = w2.Value();
251   if (f1 == Limits::PosInfinity()) {
252     return w1;
253   } else if (f2 == Limits::PosInfinity()) {
254     return w2;
255   } else {
256     return TropicalWeightTpl<T>(f1 + f2);
257   }
258 }
259
260 inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
261                                       const TropicalWeightTpl<float> &w2) {
262   return Times<float>(w1, w2);
263 }
264
265 inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
266                                        const TropicalWeightTpl<double> &w2) {
267   return Times<double>(w1, w2);
268 }
269
270 template <class T>
271 inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
272                                    const TropicalWeightTpl<T> &w2,
273                                    DivideType typ = DIVIDE_ANY) {
274   using Limits = FloatLimits<T>;
275   if (!w1.Member() || !w2.Member()) return TropicalWeightTpl<T>::NoWeight();
276   const T f1 = w1.Value();
277   const T f2 = w2.Value();
278   if (f2 == Limits::PosInfinity()) {
279     return Limits::NumberBad();
280   } else if (f1 == Limits::PosInfinity()) {
281     return Limits::PosInfinity();
282   } else {
283     return TropicalWeightTpl<T>(f1 - f2);
284   }
285 }
286
287 inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
288                                        const TropicalWeightTpl<float> &w2,
289                                        DivideType typ = DIVIDE_ANY) {
290   return Divide<float>(w1, w2, typ);
291 }
292
293 inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
294                                         const TropicalWeightTpl<double> &w2,
295                                         DivideType typ = DIVIDE_ANY) {
296   return Divide<double>(w1, w2, typ);
297 }
298
299 template <class T>
300 inline TropicalWeightTpl<T> Power(const TropicalWeightTpl<T> &weight,
301                                   T scalar) {
302   return TropicalWeightTpl<T>(weight.Value() * scalar);
303 }
304
305 // Log semiring: (log(e^-x + e^-y), +, inf, 0).
306 template <class T>
307 class LogWeightTpl : public FloatWeightTpl<T> {
308  public:
309   using typename FloatWeightTpl<T>::ValueType;
310   using FloatWeightTpl<T>::Value;
311   using ReverseWeight = LogWeightTpl;
312   using Limits = FloatLimits<T>;
313
314   constexpr LogWeightTpl() : FloatWeightTpl<T>() {}
315
316   constexpr LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
317
318   constexpr LogWeightTpl(const LogWeightTpl<T> &weight)
319       : FloatWeightTpl<T>(weight) {}
320
321   static const LogWeightTpl &Zero() {
322     static const LogWeightTpl zero(Limits::PosInfinity());
323     return zero;
324   }
325
326   static const LogWeightTpl &One() {
327     static const LogWeightTpl one(0.0F);
328     return one;
329   }
330
331   static const LogWeightTpl &NoWeight() {
332     static const LogWeightTpl no_weight(Limits::NumberBad());
333     return no_weight;
334   }
335
336   static const string &Type() {
337     static const string *const type =
338         new string(string("log") + FloatWeightTpl<T>::GetPrecisionString());
339     return *type;
340   }
341
342   bool Member() const {
343     // First part fails for IEEE NaN.
344     return Value() == Value() && Value() != Limits::NegInfinity();
345   }
346
347   LogWeightTpl<T> Quantize(float delta = kDelta) const {
348     if (!Member() || Value() == Limits::PosInfinity()) {
349       return *this;
350     } else {
351       return LogWeightTpl<T>(floor(Value() / delta + 0.5F) * delta);
352     }
353   }
354
355   LogWeightTpl<T> Reverse() const { return *this; }
356
357   static constexpr uint64 Properties() {
358     return kLeftSemiring | kRightSemiring | kCommutative;
359   }
360 };
361
362 // Single-precision log weight.
363 using LogWeight = LogWeightTpl<float>;
364
365 // Double-precision log weight.
366 using Log64Weight = LogWeightTpl<double>;
367
368 namespace internal {
369
370 // -log(e^-x + e^-y) = x - LogPosExp(y - x), assuming x >= 0.0.
371 inline double LogPosExp(double x) { return log1p(exp(-x)); }
372
373 // -log(e^-x - e^-y) = x - LogNegExp(y - x), assuming x > 0.0.
374 inline double LogNegExp(double x) { return log1p(-exp(-x)); }
375
376 // a +_log b = -log(e^-a + e^-b) = KahanLogSum(a, b, ...).
377 // Kahan compensated summation provides an error bound that is
378 // independent of the number of addends. Assumes b >= a;
379 // c is the compensation.
380 inline double KahanLogSum(double a, double b, double *c) {
381   double y = -LogPosExp(b - a) - *c;
382   double t = a + y;
383   *c = (t - a) - y;
384   return t;
385 }
386
387 // a -_log b = -log(e^-a - e^-b) = KahanLogDiff(a, b, ...).
388 // Kahan compensated summation provides an error bound that is
389 // independent of the number of addends. Assumes b > a;
390 // c is the compensation.
391 inline double KahanLogDiff(double a, double b, double *c) {
392   double y = -LogNegExp(b - a) - *c;
393   double t = a + y;
394   *c = (t - a) - y;
395   return t;
396 }
397
398 }  // namespace internal
399
400 template <class T>
401 inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
402                             const LogWeightTpl<T> &w2) {
403   using Limits = FloatLimits<T>;
404   const T f1 = w1.Value();
405   const T f2 = w2.Value();
406   if (f1 == Limits::PosInfinity()) {
407     return w2;
408   } else if (f2 == Limits::PosInfinity()) {
409     return w1;
410   } else if (f1 > f2) {
411     return LogWeightTpl<T>(f2 - internal::LogPosExp(f1 - f2));
412   } else {
413     return LogWeightTpl<T>(f1 - internal::LogPosExp(f2 - f1));
414   }
415 }
416
417 inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
418                                 const LogWeightTpl<float> &w2) {
419   return Plus<float>(w1, w2);
420 }
421
422 inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
423                                  const LogWeightTpl<double> &w2) {
424   return Plus<double>(w1, w2);
425 }
426
427 template <class T>
428 inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
429                              const LogWeightTpl<T> &w2) {
430   using Limits = FloatLimits<T>;
431   if (!w1.Member() || !w2.Member()) return LogWeightTpl<T>::NoWeight();
432   const T f1 = w1.Value();
433   const T f2 = w2.Value();
434   if (f1 == Limits::PosInfinity()) {
435     return w1;
436   } else if (f2 == Limits::PosInfinity()) {
437     return w2;
438   } else {
439     return LogWeightTpl<T>(f1 + f2);
440   }
441 }
442
443 inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
444                                  const LogWeightTpl<float> &w2) {
445   return Times<float>(w1, w2);
446 }
447
448 inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
449                                   const LogWeightTpl<double> &w2) {
450   return Times<double>(w1, w2);
451 }
452
453 template <class T>
454 inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
455                               const LogWeightTpl<T> &w2,
456                               DivideType typ = DIVIDE_ANY) {
457   using Limits = FloatLimits<T>;
458   if (!w1.Member() || !w2.Member()) return LogWeightTpl<T>::NoWeight();
459   const T f1 = w1.Value();
460   const T f2 = w2.Value();
461   if (f2 == Limits::PosInfinity()) {
462     return Limits::NumberBad();
463   } else if (f1 == Limits::PosInfinity()) {
464     return Limits::PosInfinity();
465   } else {
466     return LogWeightTpl<T>(f1 - f2);
467   }
468 }
469
470 inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
471                                   const LogWeightTpl<float> &w2,
472                                   DivideType typ = DIVIDE_ANY) {
473   return Divide<float>(w1, w2, typ);
474 }
475
476 inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
477                                    const LogWeightTpl<double> &w2,
478                                    DivideType typ = DIVIDE_ANY) {
479   return Divide<double>(w1, w2, typ);
480 }
481
482 template <class T>
483 inline LogWeightTpl<T> Power(const LogWeightTpl<T> &weight, T scalar) {
484   return LogWeightTpl<T>(weight.Value() * scalar);
485 }
486
487 // Specialization using the Kahan compensated summation.
488 template <class T>
489 class Adder<LogWeightTpl<T>> {
490  public:
491   using Weight = LogWeightTpl<T>;
492
493   explicit Adder(Weight w = Weight::Zero())
494       : sum_(w.Value()),
495         c_(0.0) { }
496
497   Weight Add(const Weight &w) {
498   using Limits = FloatLimits<T>;
499     const T f = w.Value();
500     if (f == Limits::PosInfinity()) {
501       return Sum();
502     } else if (sum_ == Limits::PosInfinity()) {
503       sum_ = f;
504       c_ = 0.0;
505     } else if (f > sum_) {
506       sum_ = internal::KahanLogSum(sum_, f, &c_);
507     } else {
508       sum_ = internal::KahanLogSum(f, sum_, &c_);
509     }
510     return Sum();
511   }
512
513   Weight Sum() { return Weight(sum_); }
514
515   void Reset(Weight w = Weight::Zero()) {
516     sum_ = w.Value();
517     c_ = 0.0;
518   }
519
520  private:
521   double sum_;
522   double c_;   // Kahan compensation.
523 };
524
525 // MinMax semiring: (min, max, inf, -inf).
526 template <class T>
527 class MinMaxWeightTpl : public FloatWeightTpl<T> {
528  public:
529   using typename FloatWeightTpl<T>::ValueType;
530   using FloatWeightTpl<T>::Value;
531   using ReverseWeight = MinMaxWeightTpl<T>;
532   using Limits = FloatLimits<T>;
533
534   MinMaxWeightTpl() : FloatWeightTpl<T>() {}
535
536   MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}
537
538   MinMaxWeightTpl(const MinMaxWeightTpl<T> &weight)
539       : FloatWeightTpl<T>(weight) {}
540
541   static const MinMaxWeightTpl &Zero() {
542     static const MinMaxWeightTpl zero(Limits::PosInfinity());
543     return zero;
544   }
545
546   static const MinMaxWeightTpl &One() {
547     static const MinMaxWeightTpl one(Limits::NegInfinity());
548     return one;
549   }
550
551   static const MinMaxWeightTpl &NoWeight() {
552     static const MinMaxWeightTpl no_weight(Limits::NumberBad());
553     return no_weight;
554   }
555
556   static const string &Type() {
557     static const string *const type =
558         new string(string("minmax") + FloatWeightTpl<T>::GetPrecisionString());
559     return *type;
560   }
561
562   // Fails for IEEE NaN.
563   bool Member() const { return Value() == Value(); }
564
565   MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
566     // If one of infinities, or a NaN.
567     if (!Member() ||
568         Value() == Limits::NegInfinity() || Value() == Limits::PosInfinity()) {
569       return *this;
570     } else {
571       return MinMaxWeightTpl<T>(floor(Value() / delta + 0.5F) * delta);
572     }
573   }
574
575   MinMaxWeightTpl<T> Reverse() const { return *this; }
576
577   static constexpr uint64 Properties() {
578     return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
579   }
580 };
581
582 // Single-precision min-max weight.
583 using MinMaxWeight = MinMaxWeightTpl<float>;
584
585 // Min.
586 template <class T>
587 inline MinMaxWeightTpl<T> Plus(const MinMaxWeightTpl<T> &w1,
588                                const MinMaxWeightTpl<T> &w2) {
589   if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl<T>::NoWeight();
590   return w1.Value() < w2.Value() ? w1 : w2;
591 }
592
593 inline MinMaxWeightTpl<float> Plus(const MinMaxWeightTpl<float> &w1,
594                                    const MinMaxWeightTpl<float> &w2) {
595   return Plus<float>(w1, w2);
596 }
597
598 inline MinMaxWeightTpl<double> Plus(const MinMaxWeightTpl<double> &w1,
599                                     const MinMaxWeightTpl<double> &w2) {
600   return Plus<double>(w1, w2);
601 }
602
603 // Max.
604 template <class T>
605 inline MinMaxWeightTpl<T> Times(const MinMaxWeightTpl<T> &w1,
606                                 const MinMaxWeightTpl<T> &w2) {
607   if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl<T>::NoWeight();
608   return w1.Value() >= w2.Value() ? w1 : w2;
609 }
610
611 inline MinMaxWeightTpl<float> Times(const MinMaxWeightTpl<float> &w1,
612                                     const MinMaxWeightTpl<float> &w2) {
613   return Times<float>(w1, w2);
614 }
615
616 inline MinMaxWeightTpl<double> Times(const MinMaxWeightTpl<double> &w1,
617                                      const MinMaxWeightTpl<double> &w2) {
618   return Times<double>(w1, w2);
619 }
620
621 // Defined only for special cases.
622 template <class T>
623 inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
624                                  const MinMaxWeightTpl<T> &w2,
625                                  DivideType typ = DIVIDE_ANY) {
626   if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl<T>::NoWeight();
627   // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2.
628   return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::NumberBad();
629 }
630
631 inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
632                                      const MinMaxWeightTpl<float> &w2,
633                                      DivideType typ = DIVIDE_ANY) {
634   return Divide<float>(w1, w2, typ);
635 }
636
637 inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
638                                       const MinMaxWeightTpl<double> &w2,
639                                       DivideType typ = DIVIDE_ANY) {
640   return Divide<double>(w1, w2, typ);
641 }
642
643 // Converts to tropical.
644 template <>
645 struct WeightConvert<LogWeight, TropicalWeight> {
646   TropicalWeight operator()(const LogWeight &w) const { return w.Value(); }
647 };
648
649 template <>
650 struct WeightConvert<Log64Weight, TropicalWeight> {
651   TropicalWeight operator()(const Log64Weight &w) const { return w.Value(); }
652 };
653
654 // Converts to log.
655 template <>
656 struct WeightConvert<TropicalWeight, LogWeight> {
657   LogWeight operator()(const TropicalWeight &w) const { return w.Value(); }
658 };
659
660 template <>
661 struct WeightConvert<Log64Weight, LogWeight> {
662   LogWeight operator()(const Log64Weight &w) const { return w.Value(); }
663 };
664
665 // Converts to log64.
666 template <>
667 struct WeightConvert<TropicalWeight, Log64Weight> {
668   Log64Weight operator()(const TropicalWeight &w) const { return w.Value(); }
669 };
670
671 template <>
672 struct WeightConvert<LogWeight, Log64Weight> {
673   Log64Weight operator()(const LogWeight &w) const { return w.Value(); }
674 };
675
676 // This function object returns random integers chosen from [0,
677 // num_random_weights). The boolean 'allow_zero' determines whether Zero() and
678 // zero divisors should be returned in the random weight generation. This is
679 // intended primary for testing.
680 template <class Weight>
681 class FloatWeightGenerate {
682  public:
683   explicit FloatWeightGenerate(
684       bool allow_zero = true,
685       const size_t num_random_weights = kNumRandomWeights)
686       : allow_zero_(allow_zero), num_random_weights_(num_random_weights) {}
687
688   Weight operator()() const {
689     const int n = rand() % (num_random_weights_ + allow_zero_);  // NOLINT
690     if (allow_zero_ && n == num_random_weights_) return Weight::Zero();
691     return Weight(n);
692   }
693
694  private:
695   // Permits Zero() and zero divisors.
696   const bool allow_zero_;
697   // Number of alternative random weights.
698   const size_t num_random_weights_;
699 };
700
701 template <class T>
702 class WeightGenerate<TropicalWeightTpl<T>>
703     : public FloatWeightGenerate<TropicalWeightTpl<T>> {
704  public:
705   using Weight = TropicalWeightTpl<T>;
706   using Generate = FloatWeightGenerate<Weight>;
707
708   explicit WeightGenerate(bool allow_zero = true,
709                           size_t num_random_weights = kNumRandomWeights)
710       : Generate(allow_zero, num_random_weights) {}
711
712   Weight operator()() const { return Weight(Generate::operator()()); }
713 };
714
715 template <class T>
716 class WeightGenerate<LogWeightTpl<T>>
717     : public FloatWeightGenerate<LogWeightTpl<T>> {
718  public:
719   using Weight = LogWeightTpl<T>;
720   using Generate = FloatWeightGenerate<Weight>;
721
722   explicit WeightGenerate(bool allow_zero = true,
723                           size_t num_random_weights = kNumRandomWeights)
724       : Generate(allow_zero, num_random_weights) {}
725
726   Weight operator()() const { return Weight(Generate::operator()()); }
727 };
728
729 // This function object returns random integers chosen from [0,
730 // num_random_weights). The boolean 'allow_zero' determines whether Zero() and
731 // zero divisors should be returned in the random weight generation. This is
732 // intended primary for testing.
733 template <class T>
734 class WeightGenerate<MinMaxWeightTpl<T>> {
735  public:
736   using Weight = MinMaxWeightTpl<T>;
737
738   explicit WeightGenerate(bool allow_zero = true,
739                           size_t num_random_weights = kNumRandomWeights)
740       : allow_zero_(allow_zero), num_random_weights_(num_random_weights) {}
741
742   Weight operator()() const {
743     const int n = (rand() %  // NOLINT
744                    (2 * num_random_weights_ + allow_zero_)) -
745                   num_random_weights_;
746     if (allow_zero_ && n == num_random_weights_) {
747       return Weight::Zero();
748     } else if (n == -num_random_weights_) {
749       return Weight::One();
750     } else {
751       return Weight(n);
752     }
753   }
754
755  private:
756   // Permits Zero() and zero divisors.
757   const bool allow_zero_;
758   // Number of alternative random weights.
759   const size_t num_random_weights_;
760 };
761
762 }  // namespace fst
763
764 #endif  // FST_LIB_FLOAT_WEIGHT_H_