// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// A LogWeight paired with sign information. It represents a value X in the
// linear domain as the pair <sign(X), -ln(|X|)>.
//
// The sign is stored as a TropicalWeight:
//  positive: TropicalWeight.Value() > 0.0 (recommended value 1.0)
//  negative: TropicalWeight.Value() <= 0.0 (recommended value -1.0)
11 #ifndef FST_SIGNED_LOG_WEIGHT_H_
12 #define FST_SIGNED_LOG_WEIGHT_H_
#include <cmath>

#include <fst/float-weight.h>
#include <fst/pair-weight.h>
#include <fst/product-weight.h>
23 class SignedLogWeightTpl : public PairWeight<TropicalWeight, LogWeightTpl<T>> {
25 using X1 = TropicalWeight;
26 using X2 = LogWeightTpl<T>;
27 using ReverseWeight = SignedLogWeightTpl;
29 using PairWeight<X1, X2>::Value1;
30 using PairWeight<X1, X2>::Value2;
32 SignedLogWeightTpl() : PairWeight<X1, X2>() {}
34 SignedLogWeightTpl(const SignedLogWeightTpl &w) : PairWeight<X1, X2>(w) {}
36 explicit SignedLogWeightTpl(const PairWeight<X1, X2> &w)
37 : PairWeight<X1, X2>(w) {}
39 SignedLogWeightTpl(const X1 &x1, const X2 &x2) : PairWeight<X1, X2>(x1, x2) {}
41 static const SignedLogWeightTpl &Zero() {
42 static const SignedLogWeightTpl zero(X1(1.0), X2::Zero());
46 static const SignedLogWeightTpl &One() {
47 static const SignedLogWeightTpl one(X1(1.0), X2::One());
51 static const SignedLogWeightTpl &NoWeight() {
52 static const SignedLogWeightTpl no_weight(X1(1.0), X2::NoWeight());
56 static const string &Type() {
57 static const string *const type =
58 new string("signed_log_" + X1::Type() + "_" + X2::Type());
62 SignedLogWeightTpl Quantize(float delta = kDelta) const {
63 return SignedLogWeightTpl(PairWeight<X1, X2>::Quantize(delta));
66 ReverseWeight Reverse() const {
67 return SignedLogWeightTpl(PairWeight<X1, X2>::Reverse());
70 bool Member() const { return PairWeight<X1, X2>::Member(); }
72 // Neither idempotent nor path.
73 static constexpr uint64 Properties() {
74 return kLeftSemiring | kRightSemiring | kCommutative;
79 if (Value2() == X2::Zero() || Value1().Value() > 0.0) {
80 h1 = TropicalWeight(1.0).Hash();
82 h1 = TropicalWeight(-1.0).Hash();
84 size_t h2 = Value2().Hash();
85 static constexpr int lshift = 5;
86 static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5;
87 return h1 << lshift ^ h1 >> rshift ^ h2;
92 inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1,
93 const SignedLogWeightTpl<T> &w2) {
94 using X1 = TropicalWeight;
95 using X2 = LogWeightTpl<T>;
96 if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
97 const auto s1 = w1.Value1().Value() > 0.0;
98 const auto s2 = w2.Value1().Value() > 0.0;
99 const bool equal = (s1 == s2);
100 const auto f1 = w1.Value2().Value();
101 const auto f2 = w2.Value2().Value();
102 if (f1 == FloatLimits<T>::PosInfinity()) {
104 } else if (f2 == FloatLimits<T>::PosInfinity()) {
106 } else if (f1 == f2) {
108 return SignedLogWeightTpl<T>(X1(w1.Value1()), X2(f2 - log(2.0F)));
110 return SignedLogWeightTpl<T>::Zero();
112 } else if (f1 > f2) {
114 return SignedLogWeightTpl<T>(X1(w1.Value1()),
115 X2(f2 - internal::LogPosExp(f1 - f2)));
117 return SignedLogWeightTpl<T>(X1(w2.Value1()),
118 X2((f2 - internal::LogNegExp(f1 - f2))));
122 return SignedLogWeightTpl<T>(X1(w2.Value1()),
123 X2((f1 - internal::LogPosExp(f2 - f1))));
125 return SignedLogWeightTpl<T>(X1(w1.Value1()),
126 X2((f1 - internal::LogNegExp(f2 - f1))));
132 inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1,
133 const SignedLogWeightTpl<T> &w2) {
134 SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2());
135 return Plus(w1, minus_w2);
139 inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1,
140 const SignedLogWeightTpl<T> &w2) {
141 using X2 = LogWeightTpl<T>;
142 if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
143 const auto s1 = w1.Value1().Value() > 0.0;
144 const auto s2 = w2.Value1().Value() > 0.0;
145 const auto f1 = w1.Value2().Value();
146 const auto f2 = w2.Value2().Value();
148 return SignedLogWeightTpl<T>(TropicalWeight(1.0), X2(f1 + f2));
150 return SignedLogWeightTpl<T>(TropicalWeight(-1.0), X2(f1 + f2));
155 inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1,
156 const SignedLogWeightTpl<T> &w2,
157 DivideType typ = DIVIDE_ANY) {
158 using X2 = LogWeightTpl<T>;
159 if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
160 const auto s1 = w1.Value1().Value() > 0.0;
161 const auto s2 = w2.Value1().Value() > 0.0;
162 const auto f1 = w1.Value2().Value();
163 const auto f2 = w2.Value2().Value();
164 if (f2 == FloatLimits<T>::PosInfinity()) {
165 return SignedLogWeightTpl<T>(TropicalWeight(1.0),
166 X2(FloatLimits<T>::NumberBad()));
167 } else if (f1 == FloatLimits<T>::PosInfinity()) {
168 return SignedLogWeightTpl<T>(TropicalWeight(1.0),
169 X2(FloatLimits<T>::PosInfinity()));
170 } else if (s1 == s2) {
171 return SignedLogWeightTpl<T>(TropicalWeight(1.0), X2(f1 - f2));
173 return SignedLogWeightTpl<T>(TropicalWeight(-1.0), X2(f1 - f2));
178 inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
179 const SignedLogWeightTpl<T> &w2, float delta = kDelta) {
180 const auto s1 = w1.Value1().Value() > 0.0;
181 const auto s2 = w2.Value1().Value() > 0.0;
183 return ApproxEqual(w1.Value2(), w2.Value2(), delta);
185 return w1.Value2() == LogWeightTpl<T>::Zero() &&
186 w2.Value2() == LogWeightTpl<T>::Zero();
191 inline bool operator==(const SignedLogWeightTpl<T> &w1,
192 const SignedLogWeightTpl<T> &w2) {
193 const auto s1 = w1.Value1().Value() > 0.0;
194 const auto s2 = w2.Value1().Value() > 0.0;
196 return w1.Value2() == w2.Value2();
198 return (w1.Value2() == LogWeightTpl<T>::Zero()) &&
199 (w2.Value2() == LogWeightTpl<T>::Zero());
203 // Single-precision signed-log weight.
204 using SignedLogWeight = SignedLogWeightTpl<float>;
206 // Double-precision signed-log weight.
207 using SignedLog64Weight = SignedLogWeightTpl<double>;
209 template <class W1, class W2>
210 bool SignedLogConvertCheck(W1 weight) {
211 if (weight.Value1().Value() < 0.0) {
212 FSTERROR() << "WeightConvert: Can't convert weight " << weight
213 << " from " << W1::Type() << " to " << W2::Type();
219 // Specialization using the Kahan compensated summation
221 class Adder<SignedLogWeightTpl<T>> {
223 using Weight = SignedLogWeightTpl<T>;
224 using X1 = TropicalWeight;
225 using X2 = LogWeightTpl<T>;
227 explicit Adder(Weight w = Weight::Zero())
228 : ssum_(w.Value1().Value() > 0.0),
229 sum_(w.Value2().Value()),
232 Weight Add(const Weight &w) {
233 const auto sw = w.Value1().Value() > 0.0;
234 const auto f = w.Value2().Value();
235 const bool equal = (ssum_ == sw);
237 if (!Sum().Member() || f == FloatLimits<T>::PosInfinity()) {
239 } else if (!w.Member() || sum_ == FloatLimits<T>::PosInfinity()) {
243 } else if (f == sum_) {
245 sum_ = internal::KahanLogSum(sum_, f, &c_);
247 sum_ = FloatLimits<T>::PosInfinity();
251 } else if (f > sum_) {
253 sum_ = internal::KahanLogSum(sum_, f, &c_);
255 sum_ = internal::KahanLogDiff(sum_, f, &c_);
259 sum_ = internal::KahanLogSum(f, sum_, &c_);
261 sum_ = internal::KahanLogDiff(f, sum_, &c_);
268 Weight Sum() { return Weight(X1(ssum_ ? 1.0 : -1.0), X2(sum_)); }
270 void Reset(Weight w = Weight::Zero()) {
271 ssum_ = w.Value1().Value() > 0.0;
272 sum_ = w.Value2().Value();
277 bool ssum_; // true iff sign of sum is positive
278 double sum_; // unsigned sum
279 double c_; // Kahan compensation
282 // Converts to tropical.
284 struct WeightConvert<SignedLogWeight, TropicalWeight> {
285 TropicalWeight operator()(const SignedLogWeight &weight) const {
286 if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(weight)) {
287 return TropicalWeight::NoWeight();
289 return TropicalWeight(weight.Value2().Value());
294 struct WeightConvert<SignedLog64Weight, TropicalWeight> {
295 TropicalWeight operator()(const SignedLog64Weight &weight) const {
296 if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(weight)) {
297 return TropicalWeight::NoWeight();
299 return TropicalWeight(weight.Value2().Value());
305 struct WeightConvert<SignedLogWeight, LogWeight> {
306 LogWeight operator()(const SignedLogWeight &weight) const {
307 if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(weight)) {
308 return LogWeight::NoWeight();
310 return LogWeight(weight.Value2().Value());
315 struct WeightConvert<SignedLog64Weight, LogWeight> {
316 LogWeight operator()(const SignedLog64Weight &weight) const {
317 if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(weight)) {
318 return LogWeight::NoWeight();
320 return LogWeight(weight.Value2().Value());
324 // Converts to log64.
326 struct WeightConvert<SignedLogWeight, Log64Weight> {
327 Log64Weight operator()(const SignedLogWeight &weight) const {
328 if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(weight)) {
329 return Log64Weight::NoWeight();
331 return Log64Weight(weight.Value2().Value());
336 struct WeightConvert<SignedLog64Weight, Log64Weight> {
337 Log64Weight operator()(const SignedLog64Weight &weight) const {
338 if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(weight)) {
339 return Log64Weight::NoWeight();
341 return Log64Weight(weight.Value2().Value());
345 // Converts to signed log.
347 struct WeightConvert<TropicalWeight, SignedLogWeight> {
348 SignedLogWeight operator()(const TropicalWeight &weight) const {
349 return SignedLogWeight(1.0, weight.Value());
354 struct WeightConvert<LogWeight, SignedLogWeight> {
355 SignedLogWeight operator()(const LogWeight &weight) const {
356 return SignedLogWeight(1.0, weight.Value());
361 struct WeightConvert<Log64Weight, SignedLogWeight> {
362 SignedLogWeight operator()(const Log64Weight &weight) const {
363 return SignedLogWeight(1.0, weight.Value());
368 struct WeightConvert<SignedLog64Weight, SignedLogWeight> {
369 SignedLogWeight operator()(const SignedLog64Weight &weight) const {
370 return SignedLogWeight(weight.Value1(), weight.Value2().Value());
374 // Converts to signed log64.
376 struct WeightConvert<TropicalWeight, SignedLog64Weight> {
377 SignedLog64Weight operator()(const TropicalWeight &weight) const {
378 return SignedLog64Weight(1.0, weight.Value());
383 struct WeightConvert<LogWeight, SignedLog64Weight> {
384 SignedLog64Weight operator()(const LogWeight &weight) const {
385 return SignedLog64Weight(1.0, weight.Value());
390 struct WeightConvert<Log64Weight, SignedLog64Weight> {
391 SignedLog64Weight operator()(const Log64Weight &weight) const {
392 return SignedLog64Weight(1.0, weight.Value());
397 struct WeightConvert<SignedLogWeight, SignedLog64Weight> {
398 SignedLog64Weight operator()(const SignedLogWeight &weight) const {
399 return SignedLog64Weight(weight.Value1(), weight.Value2().Value());
403 // This function object returns SignedLogWeightTpl<T>'s that are random integers
404 // chosen from [0, num_random_weights) times a random sign. This is intended
405 // primarily for testing.
407 class WeightGenerate<SignedLogWeightTpl<T>> {
409 using Weight = SignedLogWeightTpl<T>;
410 using X1 = typename Weight::X1;
411 using X2 = typename Weight::X2;
413 explicit WeightGenerate(bool allow_zero = true,
414 size_t num_random_weights = kNumRandomWeights)
415 : allow_zero_(allow_zero), num_random_weights_(num_random_weights) {}
417 Weight operator()() const {
418 static const X1 negative_one(-1.0);
419 static const X1 positive_one(+1.0);
420 const int m = rand() % 2; // NOLINT
421 const int n = rand() % (num_random_weights_ + allow_zero_); // NOLINT
422 return Weight((m == 0) ? negative_one : positive_one,
423 (allow_zero_ && n == num_random_weights_) ?
428 // Permits Zero() and zero divisors.
429 const bool allow_zero_;
430 // Number of alternative random weights.
431 const size_t num_random_weights_;
436 #endif // FST_SIGNED_LOG_WEIGHT_H_