1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Float weight set and associated semiring operation definitions.
6 #ifndef FST_LIB_FLOAT_WEIGHT_H_
7 #define FST_LIB_FLOAT_WEIGHT_H_
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <cstring>

#include <istream>
#include <limits>
#include <ostream>
#include <string>

#include <fst/weight.h>
// Numeric limits class: the special floating-point values used by the
// float-based weights below.
template <class T>
class FloatLimits {
 public:
  // Positive infinity; used as Zero() by the tropical, log and min-max weights.
  static constexpr T PosInfinity() {
    return std::numeric_limits<T>::infinity();
  }

  // Negative infinity.
  static constexpr T NegInfinity() { return -PosInfinity(); }

  // Quiet NaN; used to mark invalid weights (NoWeight()).
  static constexpr T NumberBad() { return std::numeric_limits<T>::quiet_NaN(); }
};
// Weight class to be templated on floating-points types.
//
// Holds a single value of type T. Construction from a bare T is deliberately
// implicit so plain numbers can be used wherever a weight is expected.
template <class T = float>
class FloatWeightTpl {
 public:
  using ValueType = T;

  // Value-initializes (to zero) rather than leaving value_ indeterminate, so
  // that this constructor is a valid constexpr target for derived weights.
  constexpr FloatWeightTpl() : value_() {}

  // constexpr so the constexpr constructors of derived weights (e.g.
  // TropicalWeightTpl, LogWeightTpl) can delegate to it.
  constexpr FloatWeightTpl(T f) : value_(f) {}

  // Defaulted copies copy value_ exactly like member-wise copy and keep the
  // class trivially copyable.
  FloatWeightTpl(const FloatWeightTpl<T> &) = default;

  FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &) = default;

  std::istream &Read(std::istream &strm) { return ReadType(strm, &value_); }

  std::ostream &Write(std::ostream &strm) const {
    return WriteType(strm, value_);
  }

  // Hashes the raw bit pattern of the value: the first
  // min(sizeof(T), sizeof(size_t)) bytes of value_ copied into a zeroed
  // size_t. std::memcpy is the well-defined way to reinterpret the bytes.
  size_t Hash() const {
    size_t hash = 0;
    std::memcpy(&hash, &value_,
                sizeof(T) < sizeof(size_t) ? sizeof(T) : sizeof(size_t));
    return hash;
  }

  const T &Value() const { return value_; }

 protected:
  void SetValue(const T &f) { value_ = f; }

  // Suffix appended to the semiring name by the derived Type() methods: empty
  // for 4-byte values, otherwise the width in bits ("8", "16", "64"), or
  // "unknown" for any other size.
  static constexpr const char *GetPrecisionString() {
    return sizeof(T) == 4
               ? ""
               : sizeof(T) == 1
                     ? "8"
                     : sizeof(T) == 2 ? "16"
                                      : sizeof(T) == 8 ? "64" : "unknown";
  }

 private:
  T value_;
};

// Single-precision float weight.
using FloatWeight = FloatWeightTpl<float>;
91 inline bool operator==(const FloatWeightTpl<T> &w1,
92 const FloatWeightTpl<T> &w2) {
93 // Volatile qualifier thwarts over-aggressive compiler optimizations that
94 // lead to problems esp. with NaturalLess().
95 volatile T v1 = w1.Value();
96 volatile T v2 = w2.Value();
100 inline bool operator==(const FloatWeightTpl<double> &w1,
101 const FloatWeightTpl<double> &w2) {
102 return operator==<double>(w1, w2);
105 inline bool operator==(const FloatWeightTpl<float> &w1,
106 const FloatWeightTpl<float> &w2) {
107 return operator==<float>(w1, w2);
111 inline bool operator!=(const FloatWeightTpl<T> &w1,
112 const FloatWeightTpl<T> &w2) {
116 inline bool operator!=(const FloatWeightTpl<double> &w1,
117 const FloatWeightTpl<double> &w2) {
118 return operator!=<double>(w1, w2);
121 inline bool operator!=(const FloatWeightTpl<float> &w1,
122 const FloatWeightTpl<float> &w2) {
123 return operator!=<float>(w1, w2);
127 inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
128 const FloatWeightTpl<T> &w2, float delta = kDelta) {
129 return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
133 inline std::ostream &operator<<(std::ostream &strm,
134 const FloatWeightTpl<T> &w) {
135 if (w.Value() == FloatLimits<T>::PosInfinity()) {
136 return strm << "Infinity";
137 } else if (w.Value() == FloatLimits<T>::NegInfinity()) {
138 return strm << "-Infinity";
139 } else if (w.Value() != w.Value()) { // Fails for IEEE NaN.
140 return strm << "BadNumber";
142 return strm << w.Value();
147 inline std::istream &operator>>(std::istream &strm, FloatWeightTpl<T> &w) {
150 if (s == "Infinity") {
151 w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity());
152 } else if (s == "-Infinity") {
153 w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity());
156 T f = strtod(s.c_str(), &p);
157 if (p < s.c_str() + s.size()) {
158 strm.clear(std::ios::badbit);
160 w = FloatWeightTpl<T>(f);
166 // Tropical semiring: (min, +, inf, 0).
168 class TropicalWeightTpl : public FloatWeightTpl<T> {
170 using typename FloatWeightTpl<T>::ValueType;
171 using FloatWeightTpl<T>::Value;
172 using ReverseWeight = TropicalWeightTpl<T>;
173 using Limits = FloatLimits<T>;
175 constexpr TropicalWeightTpl() : FloatWeightTpl<T>() {}
177 constexpr TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
179 constexpr TropicalWeightTpl(const TropicalWeightTpl<T> &weight)
180 : FloatWeightTpl<T>(weight) {}
182 static const TropicalWeightTpl<T> &Zero() {
183 static const TropicalWeightTpl zero(Limits::PosInfinity());
187 static const TropicalWeightTpl<T> &One() {
188 static const TropicalWeightTpl one(0.0F);
192 static const TropicalWeightTpl<T> &NoWeight() {
193 static const TropicalWeightTpl no_weight(Limits::NumberBad());
197 static const string &Type() {
198 static const string *const type =
199 new string(string("tropical") +
200 FloatWeightTpl<T>::GetPrecisionString());
204 bool Member() const {
205 // First part fails for IEEE NaN.
206 return Value() == Value() && Value() != Limits::NegInfinity();
209 TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
210 if (!Member() || Value() == Limits::PosInfinity()) {
213 return TropicalWeightTpl<T>(floor(Value() / delta + 0.5F) * delta);
217 TropicalWeightTpl<T> Reverse() const { return *this; }
219 static constexpr uint64 Properties() {
220 return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
224 // Single precision tropical weight.
225 using TropicalWeight = TropicalWeightTpl<float>;
228 inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
229 const TropicalWeightTpl<T> &w2) {
230 if (!w1.Member() || !w2.Member()) return TropicalWeightTpl<T>::NoWeight();
231 return w1.Value() < w2.Value() ? w1 : w2;
234 inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
235 const TropicalWeightTpl<float> &w2) {
236 return Plus<float>(w1, w2);
239 inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
240 const TropicalWeightTpl<double> &w2) {
241 return Plus<double>(w1, w2);
245 inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
246 const TropicalWeightTpl<T> &w2) {
247 using Limits = FloatLimits<T>;
248 if (!w1.Member() || !w2.Member()) return TropicalWeightTpl<T>::NoWeight();
249 const T f1 = w1.Value();
250 const T f2 = w2.Value();
251 if (f1 == Limits::PosInfinity()) {
253 } else if (f2 == Limits::PosInfinity()) {
256 return TropicalWeightTpl<T>(f1 + f2);
260 inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
261 const TropicalWeightTpl<float> &w2) {
262 return Times<float>(w1, w2);
265 inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
266 const TropicalWeightTpl<double> &w2) {
267 return Times<double>(w1, w2);
271 inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
272 const TropicalWeightTpl<T> &w2,
273 DivideType typ = DIVIDE_ANY) {
274 using Limits = FloatLimits<T>;
275 if (!w1.Member() || !w2.Member()) return TropicalWeightTpl<T>::NoWeight();
276 const T f1 = w1.Value();
277 const T f2 = w2.Value();
278 if (f2 == Limits::PosInfinity()) {
279 return Limits::NumberBad();
280 } else if (f1 == Limits::PosInfinity()) {
281 return Limits::PosInfinity();
283 return TropicalWeightTpl<T>(f1 - f2);
287 inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
288 const TropicalWeightTpl<float> &w2,
289 DivideType typ = DIVIDE_ANY) {
290 return Divide<float>(w1, w2, typ);
293 inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
294 const TropicalWeightTpl<double> &w2,
295 DivideType typ = DIVIDE_ANY) {
296 return Divide<double>(w1, w2, typ);
300 inline TropicalWeightTpl<T> Power(const TropicalWeightTpl<T> &weight,
302 return TropicalWeightTpl<T>(weight.Value() * scalar);
305 // Log semiring: (log(e^-x + e^-y), +, inf, 0).
307 class LogWeightTpl : public FloatWeightTpl<T> {
309 using typename FloatWeightTpl<T>::ValueType;
310 using FloatWeightTpl<T>::Value;
311 using ReverseWeight = LogWeightTpl;
312 using Limits = FloatLimits<T>;
314 constexpr LogWeightTpl() : FloatWeightTpl<T>() {}
316 constexpr LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
318 constexpr LogWeightTpl(const LogWeightTpl<T> &weight)
319 : FloatWeightTpl<T>(weight) {}
321 static const LogWeightTpl &Zero() {
322 static const LogWeightTpl zero(Limits::PosInfinity());
326 static const LogWeightTpl &One() {
327 static const LogWeightTpl one(0.0F);
331 static const LogWeightTpl &NoWeight() {
332 static const LogWeightTpl no_weight(Limits::NumberBad());
336 static const string &Type() {
337 static const string *const type =
338 new string(string("log") + FloatWeightTpl<T>::GetPrecisionString());
342 bool Member() const {
343 // First part fails for IEEE NaN.
344 return Value() == Value() && Value() != Limits::NegInfinity();
347 LogWeightTpl<T> Quantize(float delta = kDelta) const {
348 if (!Member() || Value() == Limits::PosInfinity()) {
351 return LogWeightTpl<T>(floor(Value() / delta + 0.5F) * delta);
355 LogWeightTpl<T> Reverse() const { return *this; }
357 static constexpr uint64 Properties() {
358 return kLeftSemiring | kRightSemiring | kCommutative;
362 // Single-precision log weight.
363 using LogWeight = LogWeightTpl<float>;
365 // Double-precision log weight.
366 using Log64Weight = LogWeightTpl<double>;
namespace internal {

// -log(e^-x + e^-y) = x - LogPosExp(y - x). The argument d = y - x should be
// >= 0.0 so that exp(-d) <= 1 and log1p stays accurate.
inline double LogPosExp(double x) { return std::log1p(std::exp(-x)); }

// -log(e^-x - e^-y) = x - LogNegExp(y - x). The argument d = y - x must be
// > 0.0; otherwise log1p(-1) or a NaN results.
inline double LogNegExp(double x) { return std::log1p(-std::exp(-x)); }

// a +_log b = -log(e^-a + e^-b) = KahanLogSum(a, b, ...).
// Kahan compensated summation provides an error bound that is
// independent of the number of addends. Assumes b >= a;
// c is the compensation, updated in place.
inline double KahanLogSum(double a, double b, double *c) {
  const double y = -LogPosExp(b - a) - *c;
  const double t = a + y;
  // Recovers the low-order bits lost when adding y to a.
  *c = (t - a) - y;
  return t;
}

// a -_log b = -log(e^-a - e^-b) = KahanLogDiff(a, b, ...).
// Kahan compensated summation provides an error bound that is
// independent of the number of addends. Assumes b > a;
// c is the compensation, updated in place.
inline double KahanLogDiff(double a, double b, double *c) {
  const double y = -LogNegExp(b - a) - *c;
  const double t = a + y;
  // Recovers the low-order bits lost when adding y to a.
  *c = (t - a) - y;
  return t;
}

}  // namespace internal
401 inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
402 const LogWeightTpl<T> &w2) {
403 using Limits = FloatLimits<T>;
404 const T f1 = w1.Value();
405 const T f2 = w2.Value();
406 if (f1 == Limits::PosInfinity()) {
408 } else if (f2 == Limits::PosInfinity()) {
410 } else if (f1 > f2) {
411 return LogWeightTpl<T>(f2 - internal::LogPosExp(f1 - f2));
413 return LogWeightTpl<T>(f1 - internal::LogPosExp(f2 - f1));
417 inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
418 const LogWeightTpl<float> &w2) {
419 return Plus<float>(w1, w2);
422 inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
423 const LogWeightTpl<double> &w2) {
424 return Plus<double>(w1, w2);
428 inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
429 const LogWeightTpl<T> &w2) {
430 using Limits = FloatLimits<T>;
431 if (!w1.Member() || !w2.Member()) return LogWeightTpl<T>::NoWeight();
432 const T f1 = w1.Value();
433 const T f2 = w2.Value();
434 if (f1 == Limits::PosInfinity()) {
436 } else if (f2 == Limits::PosInfinity()) {
439 return LogWeightTpl<T>(f1 + f2);
443 inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
444 const LogWeightTpl<float> &w2) {
445 return Times<float>(w1, w2);
448 inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
449 const LogWeightTpl<double> &w2) {
450 return Times<double>(w1, w2);
454 inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
455 const LogWeightTpl<T> &w2,
456 DivideType typ = DIVIDE_ANY) {
457 using Limits = FloatLimits<T>;
458 if (!w1.Member() || !w2.Member()) return LogWeightTpl<T>::NoWeight();
459 const T f1 = w1.Value();
460 const T f2 = w2.Value();
461 if (f2 == Limits::PosInfinity()) {
462 return Limits::NumberBad();
463 } else if (f1 == Limits::PosInfinity()) {
464 return Limits::PosInfinity();
466 return LogWeightTpl<T>(f1 - f2);
470 inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
471 const LogWeightTpl<float> &w2,
472 DivideType typ = DIVIDE_ANY) {
473 return Divide<float>(w1, w2, typ);
476 inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
477 const LogWeightTpl<double> &w2,
478 DivideType typ = DIVIDE_ANY) {
479 return Divide<double>(w1, w2, typ);
483 inline LogWeightTpl<T> Power(const LogWeightTpl<T> &weight, T scalar) {
484 return LogWeightTpl<T>(weight.Value() * scalar);
487 // Specialization using the Kahan compensated summation.
489 class Adder<LogWeightTpl<T>> {
491 using Weight = LogWeightTpl<T>;
493 explicit Adder(Weight w = Weight::Zero())
497 Weight Add(const Weight &w) {
498 using Limits = FloatLimits<T>;
499 const T f = w.Value();
500 if (f == Limits::PosInfinity()) {
502 } else if (sum_ == Limits::PosInfinity()) {
505 } else if (f > sum_) {
506 sum_ = internal::KahanLogSum(sum_, f, &c_);
508 sum_ = internal::KahanLogSum(f, sum_, &c_);
513 Weight Sum() { return Weight(sum_); }
515 void Reset(Weight w = Weight::Zero()) {
522 double c_; // Kahan compensation.
525 // MinMax semiring: (min, max, inf, -inf).
527 class MinMaxWeightTpl : public FloatWeightTpl<T> {
529 using typename FloatWeightTpl<T>::ValueType;
530 using FloatWeightTpl<T>::Value;
531 using ReverseWeight = MinMaxWeightTpl<T>;
532 using Limits = FloatLimits<T>;
534 MinMaxWeightTpl() : FloatWeightTpl<T>() {}
536 MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}
538 MinMaxWeightTpl(const MinMaxWeightTpl<T> &weight)
539 : FloatWeightTpl<T>(weight) {}
541 static const MinMaxWeightTpl &Zero() {
542 static const MinMaxWeightTpl zero(Limits::PosInfinity());
546 static const MinMaxWeightTpl &One() {
547 static const MinMaxWeightTpl one(Limits::NegInfinity());
551 static const MinMaxWeightTpl &NoWeight() {
552 static const MinMaxWeightTpl no_weight(Limits::NumberBad());
556 static const string &Type() {
557 static const string *const type =
558 new string(string("minmax") + FloatWeightTpl<T>::GetPrecisionString());
562 // Fails for IEEE NaN.
563 bool Member() const { return Value() == Value(); }
565 MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
566 // If one of infinities, or a NaN.
568 Value() == Limits::NegInfinity() || Value() == Limits::PosInfinity()) {
571 return MinMaxWeightTpl<T>(floor(Value() / delta + 0.5F) * delta);
575 MinMaxWeightTpl<T> Reverse() const { return *this; }
577 static constexpr uint64 Properties() {
578 return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
582 // Single-precision min-max weight.
583 using MinMaxWeight = MinMaxWeightTpl<float>;
587 inline MinMaxWeightTpl<T> Plus(const MinMaxWeightTpl<T> &w1,
588 const MinMaxWeightTpl<T> &w2) {
589 if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl<T>::NoWeight();
590 return w1.Value() < w2.Value() ? w1 : w2;
593 inline MinMaxWeightTpl<float> Plus(const MinMaxWeightTpl<float> &w1,
594 const MinMaxWeightTpl<float> &w2) {
595 return Plus<float>(w1, w2);
598 inline MinMaxWeightTpl<double> Plus(const MinMaxWeightTpl<double> &w1,
599 const MinMaxWeightTpl<double> &w2) {
600 return Plus<double>(w1, w2);
605 inline MinMaxWeightTpl<T> Times(const MinMaxWeightTpl<T> &w1,
606 const MinMaxWeightTpl<T> &w2) {
607 if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl<T>::NoWeight();
608 return w1.Value() >= w2.Value() ? w1 : w2;
611 inline MinMaxWeightTpl<float> Times(const MinMaxWeightTpl<float> &w1,
612 const MinMaxWeightTpl<float> &w2) {
613 return Times<float>(w1, w2);
616 inline MinMaxWeightTpl<double> Times(const MinMaxWeightTpl<double> &w1,
617 const MinMaxWeightTpl<double> &w2) {
618 return Times<double>(w1, w2);
621 // Defined only for special cases.
623 inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
624 const MinMaxWeightTpl<T> &w2,
625 DivideType typ = DIVIDE_ANY) {
626 if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl<T>::NoWeight();
627 // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2.
628 return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::NumberBad();
631 inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
632 const MinMaxWeightTpl<float> &w2,
633 DivideType typ = DIVIDE_ANY) {
634 return Divide<float>(w1, w2, typ);
637 inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
638 const MinMaxWeightTpl<double> &w2,
639 DivideType typ = DIVIDE_ANY) {
640 return Divide<double>(w1, w2, typ);
643 // Converts to tropical.
645 struct WeightConvert<LogWeight, TropicalWeight> {
646 TropicalWeight operator()(const LogWeight &w) const { return w.Value(); }
650 struct WeightConvert<Log64Weight, TropicalWeight> {
651 TropicalWeight operator()(const Log64Weight &w) const { return w.Value(); }
656 struct WeightConvert<TropicalWeight, LogWeight> {
657 LogWeight operator()(const TropicalWeight &w) const { return w.Value(); }
661 struct WeightConvert<Log64Weight, LogWeight> {
662 LogWeight operator()(const Log64Weight &w) const { return w.Value(); }
665 // Converts to log64.
667 struct WeightConvert<TropicalWeight, Log64Weight> {
668 Log64Weight operator()(const TropicalWeight &w) const { return w.Value(); }
672 struct WeightConvert<LogWeight, Log64Weight> {
673 Log64Weight operator()(const LogWeight &w) const { return w.Value(); }
676 // This function object returns random integers chosen from [0,
677 // num_random_weights). The boolean 'allow_zero' determines whether Zero() and
678 // zero divisors should be returned in the random weight generation. This is
679 // intended primary for testing.
680 template <class Weight>
681 class FloatWeightGenerate {
683 explicit FloatWeightGenerate(
684 bool allow_zero = true,
685 const size_t num_random_weights = kNumRandomWeights)
686 : allow_zero_(allow_zero), num_random_weights_(num_random_weights) {}
688 Weight operator()() const {
689 const int n = rand() % (num_random_weights_ + allow_zero_); // NOLINT
690 if (allow_zero_ && n == num_random_weights_) return Weight::Zero();
695 // Permits Zero() and zero divisors.
696 const bool allow_zero_;
697 // Number of alternative random weights.
698 const size_t num_random_weights_;
702 class WeightGenerate<TropicalWeightTpl<T>>
703 : public FloatWeightGenerate<TropicalWeightTpl<T>> {
705 using Weight = TropicalWeightTpl<T>;
706 using Generate = FloatWeightGenerate<Weight>;
708 explicit WeightGenerate(bool allow_zero = true,
709 size_t num_random_weights = kNumRandomWeights)
710 : Generate(allow_zero, num_random_weights) {}
712 Weight operator()() const { return Weight(Generate::operator()()); }
716 class WeightGenerate<LogWeightTpl<T>>
717 : public FloatWeightGenerate<LogWeightTpl<T>> {
719 using Weight = LogWeightTpl<T>;
720 using Generate = FloatWeightGenerate<Weight>;
722 explicit WeightGenerate(bool allow_zero = true,
723 size_t num_random_weights = kNumRandomWeights)
724 : Generate(allow_zero, num_random_weights) {}
726 Weight operator()() const { return Weight(Generate::operator()()); }
729 // This function object returns random integers chosen from [0,
730 // num_random_weights). The boolean 'allow_zero' determines whether Zero() and
731 // zero divisors should be returned in the random weight generation. This is
732 // intended primary for testing.
734 class WeightGenerate<MinMaxWeightTpl<T>> {
736 using Weight = MinMaxWeightTpl<T>;
738 explicit WeightGenerate(bool allow_zero = true,
739 size_t num_random_weights = kNumRandomWeights)
740 : allow_zero_(allow_zero), num_random_weights_(num_random_weights) {}
742 Weight operator()() const {
743 const int n = (rand() % // NOLINT
744 (2 * num_random_weights_ + allow_zero_)) -
746 if (allow_zero_ && n == num_random_weights_) {
747 return Weight::Zero();
748 } else if (n == -num_random_weights_) {
749 return Weight::One();
756 // Permits Zero() and zero divisors.
757 const bool allow_zero_;
758 // Number of alternative random weights.
759 const size_t num_random_weights_;
764 #endif // FST_LIB_FLOAT_WEIGHT_H_