Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / include / fst / power-weight.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Cartesian power weight semiring operation definitions.
5
6 #ifndef FST_LIB_POWER_WEIGHT_H_
7 #define FST_LIB_POWER_WEIGHT_H_
8
9 #include <string>
10
11 #include <fst/tuple-weight.h>
12 #include <fst/weight.h>
13
14
15 namespace fst {
16
17 // Cartesian power semiring: W ^ n
18 //
19 // Forms:
20 //  - a left semimodule when W is a left semiring,
21 //  - a right semimodule when W is a right semiring,
22 //  - a bisemimodule when W is a semiring,
23 //    the free semimodule of rank n over W
24 // The Times operation is overloaded to provide the left and right scalar
25 // products.
26 template <class W, size_t n>
27 class PowerWeight : public TupleWeight<W, n> {
28  public:
29   using ReverseWeight = PowerWeight<typename W::ReverseWeight, n>;
30
31   PowerWeight() {}
32
33   explicit PowerWeight(const TupleWeight<W, n> &weight)
34       : TupleWeight<W, n>(weight) {}
35
36   template <class Iterator>
37   PowerWeight(Iterator begin, Iterator end) : TupleWeight<W, n>(begin, end) {}
38
39   static const PowerWeight &Zero() {
40     static const PowerWeight zero(TupleWeight<W, n>::Zero());
41     return zero;
42   }
43
44   static const PowerWeight &One() {
45     static const PowerWeight one(TupleWeight<W, n>::One());
46     return one;
47   }
48
49   static const PowerWeight &NoWeight() {
50     static const PowerWeight no_weight(TupleWeight<W, n>::NoWeight());
51     return no_weight;
52   }
53
54   static const string &Type() {
55     static const string *const type =
56         new string(W::Type() + "_^" + std::to_string(n));
57     return *type;
58   }
59
60   static constexpr uint64 Properties() {
61     return W::Properties() &
62            (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent);
63   }
64
65   PowerWeight Quantize(float delta = kDelta) const {
66     return PowerWeight(TupleWeight<W, n>::Quantize(delta));
67   }
68
69   ReverseWeight Reverse() const {
70     return ReverseWeight(TupleWeight<W, n>::Reverse());
71   }
72 };
73
74 // Semiring plus operation.
75 template <class W, size_t n>
76 inline PowerWeight<W, n> Plus(const PowerWeight<W, n> &w1,
77                               const PowerWeight<W, n> &w2) {
78   PowerWeight<W, n> result;
79   for (size_t i = 0; i < n; ++i) {
80     result.SetValue(i, Plus(w1.Value(i), w2.Value(i)));
81   }
82   return result;
83 }
84
85 // Semiring times operation.
86 template <class W, size_t n>
87 inline PowerWeight<W, n> Times(const PowerWeight<W, n> &w1,
88                                const PowerWeight<W, n> &w2) {
89   PowerWeight<W, n> result;
90   for (size_t i = 0; i < n; ++i) {
91     result.SetValue(i, Times(w1.Value(i), w2.Value(i)));
92   }
93   return result;
94 }
95
96 // Semiring divide operation.
97 template <class W, size_t n>
98 inline PowerWeight<W, n> Divide(const PowerWeight<W, n> &w1,
99                                 const PowerWeight<W, n> &w2,
100                                 DivideType type = DIVIDE_ANY) {
101   PowerWeight<W, n> result;
102   for (size_t i = 0; i < n; ++i) {
103     result.SetValue(i, Divide(w1.Value(i), w2.Value(i), type));
104   }
105   return result;
106 }
107
108 // Semimodule left scalar product.
109 template <class W, size_t n>
110 inline PowerWeight<W, n> Times(const W &scalar,
111                                const PowerWeight<W, n> &weight) {
112   PowerWeight<W, n> result;
113   for (size_t i = 0; i < n; ++i) {
114     result.SetValue(i, Times(scalar, weight.Value(i)));
115   }
116   return result;
117 }
118
119 // Semimodule right scalar product.
120 template <class W, size_t n>
121 inline PowerWeight<W, n> Times(const PowerWeight<W, n> &weight,
122                                const W &scalar) {
123   PowerWeight<W, n> result;
124   for (size_t i = 0; i < n; ++i) {
125     result.SetValue(i, Times(weight.Value(i), scalar));
126   }
127   return result;
128 }
129
130 // Semimodule dot product.
131 template <class W, size_t n>
132 inline W DotProduct(const PowerWeight<W, n> &w1, const PowerWeight<W, n> &w2) {
133   W result(W::Zero());
134   for (size_t i = 0; i < n; ++i) {
135     result = Plus(result, Times(w1.Value(i), w2.Value(i)));
136   }
137   return result;
138 }
139
140 // This function object generates weights over the Cartesian power of rank
141 // n over the underlying weight. This is intended primarily for testing.
142 template <class W, size_t n>
143 class WeightGenerate<PowerWeight<W, n>> {
144  public:
145   using Weight = PowerWeight<W, n>;
146   using Generate = WeightGenerate<W>;
147
148   explicit WeightGenerate(bool allow_zero = true) : generate_(allow_zero) {}
149
150   Weight operator()() const {
151     Weight result;
152     for (size_t i = 0; i < n; ++i) result.SetValue(i, generate_());
153     return result;
154   }
155
156  private:
157   Generate generate_;
158 };
159
160 }  // namespace fst
161
162 #endif  // FST_LIB_POWER_WEIGHT_H_