937897faf4e57751d7db146c1a1878fbd08a5b4f
[platform/upstream/openfst.git] / src / include / fst / power-weight.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Cartesian power weight semiring operation definitions.
5
6 #ifndef FST_LIB_POWER_WEIGHT_H_
7 #define FST_LIB_POWER_WEIGHT_H_
8
9 #include <string>
10
11 #include <fst/tuple-weight.h>
12 #include <fst/weight.h>
13
14
15 namespace fst {
16
17 // Cartesian power semiring: W ^ n
18 //
19 // Forms:
20 //  - a left semimodule when W is a left semiring,
21 //  - a right semimodule when W is a right semiring,
22 //  - a bisemimodule when W is a semiring,
23 //    the free semimodule of rank n over W
24 // The Times operation is overloaded to provide the left and right scalar
25 // products.
26 template <class W, size_t n>
27 class PowerWeight : public TupleWeight<W, n> {
28  public:
29   using ReverseWeight = PowerWeight<typename W::ReverseWeight, n>;
30
31   PowerWeight() {}
32
33   explicit PowerWeight(const TupleWeight<W, n> &weight)
34       : TupleWeight<W, n>(weight) {}
35
36   template <class Iterator>
37   PowerWeight(Iterator begin, Iterator end) : TupleWeight<W, n>(begin, end) {}
38
39   static const PowerWeight &Zero() {
40     static const PowerWeight zero(TupleWeight<W, n>::Zero());
41     return zero;
42   }
43
44   static const PowerWeight &One() {
45     static const PowerWeight one(TupleWeight<W, n>::One());
46     return one;
47   }
48
49   static const PowerWeight &NoWeight() {
50     static const PowerWeight no_weight(TupleWeight<W, n>::NoWeight());
51     return no_weight;
52   }
53
54   static const string &Type() {
55     static string type;
56     if (type.empty()) {
57       type = W::Type() + "_^" + std::to_string(n);
58     }
59     return type;
60   }
61
62   static constexpr uint64 Properties() {
63     return W::Properties() &
64            (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent);
65   }
66
67   PowerWeight Quantize(float delta = kDelta) const {
68     return PowerWeight(TupleWeight<W, n>::Quantize(delta));
69   }
70
71   ReverseWeight Reverse() const {
72     return ReverseWeight(TupleWeight<W, n>::Reverse());
73   }
74 };
75
76 // Semiring plus operation.
77 template <class W, size_t n>
78 inline PowerWeight<W, n> Plus(const PowerWeight<W, n> &w1,
79                               const PowerWeight<W, n> &w2) {
80   PowerWeight<W, n> result;
81   for (size_t i = 0; i < n; ++i) {
82     result.SetValue(i, Plus(w1.Value(i), w2.Value(i)));
83   }
84   return result;
85 }
86
87 // Semiring times operation.
88 template <class W, size_t n>
89 inline PowerWeight<W, n> Times(const PowerWeight<W, n> &w1,
90                                const PowerWeight<W, n> &w2) {
91   PowerWeight<W, n> result;
92   for (size_t i = 0; i < n; ++i) {
93     result.SetValue(i, Times(w1.Value(i), w2.Value(i)));
94   }
95   return result;
96 }
97
98 // Semiring divide operation.
99 template <class W, size_t n>
100 inline PowerWeight<W, n> Divide(const PowerWeight<W, n> &w1,
101                                 const PowerWeight<W, n> &w2,
102                                 DivideType type = DIVIDE_ANY) {
103   PowerWeight<W, n> result;
104   for (size_t i = 0; i < n; ++i) {
105     result.SetValue(i, Divide(w1.Value(i), w2.Value(i), type));
106   }
107   return result;
108 }
109
110 // Semimodule left scalar product.
111 template <class W, size_t n>
112 inline PowerWeight<W, n> Times(const W &scalar,
113                                const PowerWeight<W, n> &weight) {
114   PowerWeight<W, n> result;
115   for (size_t i = 0; i < n; ++i) {
116     result.SetValue(i, Times(scalar, weight.Value(i)));
117   }
118   return result;
119 }
120
121 // Semimodule right scalar product.
122 template <class W, size_t n>
123 inline PowerWeight<W, n> Times(const PowerWeight<W, n> &weight,
124                                const W &scalar) {
125   PowerWeight<W, n> result;
126   for (size_t i = 0; i < n; ++i) {
127     result.SetValue(i, Times(weight.Value(i), scalar));
128   }
129   return result;
130 }
131
132 // Semimodule dot product.
133 template <class W, size_t n>
134 inline W DotProduct(const PowerWeight<W, n> &w1, const PowerWeight<W, n> &w2) {
135   W result(W::Zero());
136   for (size_t i = 0; i < n; ++i) {
137     result = Plus(result, Times(w1.Value(i), w2.Value(i)));
138   }
139   return result;
140 }
141
142 // This function object generates weights over the Cartesian power of rank
143 // n over the underlying weight. This is intended primarily for testing.
144 template <class W, size_t n>
145 class WeightGenerate<PowerWeight<W, n>> {
146  public:
147   using Weight = PowerWeight<W, n>;
148   using Generate = WeightGenerate<W>;
149
150   explicit WeightGenerate(bool allow_zero = true) : generate_(allow_zero) {}
151
152   Weight operator()() const {
153     Weight result;
154     for (size_t i = 0; i < n; ++i) result.SetValue(i, generate_());
155     return result;
156   }
157
158  private:
159   Generate generate_;
160 };
161
162 }  // namespace fst
163
164 #endif  // FST_LIB_POWER_WEIGHT_H_