1b6f83c88cef3cb1121af11777f2c8d79c030788
[platform/upstream/openfst.git] / src / include / fst / sparse-power-weight.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Cartesian power weight semiring operation definitions, using
5 // SparseTupleWeight as underlying representation.
6
7 #ifndef FST_LIB_SPARSE_POWER_WEIGHT_H_
8 #define FST_LIB_SPARSE_POWER_WEIGHT_H_
9
10 #include <climits>
11 #include <string>
12
13 #include <fst/sparse-tuple-weight.h>
14 #include <fst/weight.h>
15
16
17 namespace fst {
18
19 // Below SparseTupleWeight*Mapper are used in conjunction with
20 // SparseTupleWeightMap to compute the respective semiring operations
21 template <class W, class K>
22 struct SparseTupleWeightPlusMapper {
23   W Map(const K &k, const W &v1, const W &v2) const { return Plus(v1, v2); }
24 };
25
26 template <class W, class K>
27 struct SparseTupleWeightTimesMapper {
28   W Map(const K &k, const W &v1, const W &v2) const { return Times(v1, v2); }
29 };
30
31 template <class W, class K>
32 struct SparseTupleWeightDivideMapper {
33   const DivideType type;
34
35   explicit SparseTupleWeightDivideMapper(DivideType type_) : type(type_) {}
36
37   W Map(const K &k, const W &v1, const W &v2) const {
38     return Divide(v1, v2, type);
39   }
40 };
41
42 template <class W, class K>
43 struct SparseTupleWeightApproxMapper {
44   const float delta;
45
46   explicit SparseTupleWeightApproxMapper(float delta_ = kDelta)
47       : delta(delta_) {}
48
49   W Map(const K &k, const W &v1, const W &v2) const {
50     return ApproxEqual(v1, v2, delta) ? W::One() : W::Zero();
51   }
52 };
53
54 // Sparse cartesian power semiring: W ^ n
55 //
56 // Forms:
57 //
58 //  - a left semimodule when W is a left semiring,
59 //  - a right semimodule when W is a right semiring,
60 //  - a bisemimodule when W is a semiring,
61 //    the free semimodule of rank n over W
62 //
63 // The Times operation is overloaded to provide the left and right scalar
64 // products.
65 //
66 // K is the key value type. kNoKey (-1) is reserved for internal use
67 template <class W, class K = int>
68 class SparsePowerWeight : public SparseTupleWeight<W, K> {
69  public:
70   using ReverseWeight = SparsePowerWeight<typename W::ReverseWeight, K>;
71
72   SparsePowerWeight() {}
73
74   explicit SparsePowerWeight(const SparseTupleWeight<W, K> &weight)
75       : SparseTupleWeight<W, K>(weight) {}
76
77   template <class Iterator>
78   SparsePowerWeight(Iterator begin, Iterator end)
79       : SparseTupleWeight<W, K>(begin, end) {}
80
81   SparsePowerWeight(const K &key, const W &weight)
82       : SparseTupleWeight<W, K>(key, weight) {}
83
84   static const SparsePowerWeight &Zero() {
85     static const SparsePowerWeight zero(SparseTupleWeight<W, K>::Zero());
86     return zero;
87   }
88
89   static const SparsePowerWeight &One() {
90     static const SparsePowerWeight one(SparseTupleWeight<W, K>::One());
91     return one;
92   }
93
94   static const SparsePowerWeight &NoWeight() {
95     static const SparsePowerWeight no_weight(
96         SparseTupleWeight<W, K>::NoWeight());
97     return no_weight;
98   }
99
100   // Overide this: Overwrite the Type method to reflect the key type if using
101   // a non-default key type.
102   static const string &Type() {
103     static string type;
104     if (type.empty()) {
105       type = W::Type() + "_^n";
106       if (sizeof(K) != sizeof(uint32)) {
107         type += "_" + std::to_string(CHAR_BIT * sizeof(K));
108       }
109     }
110     return type;
111   }
112
113   static constexpr uint64 Properties() {
114     return W::Properties() &
115            (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent);
116   }
117
118   SparsePowerWeight Quantize(float delta = kDelta) const {
119     return SparsePowerWeight(SparseTupleWeight<W, K>::Quantize(delta));
120   }
121
122   ReverseWeight Reverse() const {
123     return ReverseWeight(SparseTupleWeight<W, K>::Reverse());
124   }
125 };
126
127 // Semimodule plus operation.
128 template <class W, class K>
129 inline SparsePowerWeight<W, K> Plus(const SparsePowerWeight<W, K> &w1,
130                                     const SparsePowerWeight<W, K> &w2) {
131   SparsePowerWeight<W, K> result;
132   SparseTupleWeightPlusMapper<W, K> operator_mapper;
133   SparseTupleWeightMap(&result, w1, w2, operator_mapper);
134   return result;
135 }
136
137 // Semimodule times operation.
138 template <class W, class K>
139 inline SparsePowerWeight<W, K> Times(const SparsePowerWeight<W, K> &w1,
140                                      const SparsePowerWeight<W, K> &w2) {
141   SparsePowerWeight<W, K> result;
142   SparseTupleWeightTimesMapper<W, K> operator_mapper;
143   SparseTupleWeightMap(&result, w1, w2, operator_mapper);
144   return result;
145 }
146
147 // Semimodule divide operation.
148 template <class W, class K>
149 inline SparsePowerWeight<W, K> Divide(const SparsePowerWeight<W, K> &w1,
150                                       const SparsePowerWeight<W, K> &w2,
151                                       DivideType type = DIVIDE_ANY) {
152   SparsePowerWeight<W, K> result;
153   SparseTupleWeightDivideMapper<W, K> operator_mapper(type);
154   SparseTupleWeightMap(&result, w1, w2, operator_mapper);
155   return result;
156 }
157
158 // Semimodule dot product operation.
159 template <class W, class K>
160 inline const W &DotProduct(const SparsePowerWeight<W, K> &w1,
161                            const SparsePowerWeight<W, K> &w2) {
162   const SparsePowerWeight<W, K> product = Times(w1, w2);
163   W result(W::Zero());
164   for (SparseTupleWeightIterator<W, K> it(product); !it.Done(); it.Next()) {
165     result = Plus(result, it.Value().second);
166   }
167   return result;
168 }
169
170 template <class W, class K>
171 inline bool ApproxEqual(const SparsePowerWeight<W, K> &w1,
172                         const SparsePowerWeight<W, K> &w2,
173                         float delta = kDelta) {
174   SparseTupleWeight<W, K> result;
175   SparseTupleWeightApproxMapper<W, K> operator_mapper(kDelta);
176   SparseTupleWeightMap(&result, w1, w2, operator_mapper);
177   return result == SparsePowerWeight<W, K>::One();
178 }
179
180 template <class W, class K>
181 inline SparsePowerWeight<W, K> Times(const W &k,
182                                      const SparsePowerWeight<W, K> &w2) {
183   const SparseTupleWeight<W, K> t2(k);
184   const SparsePowerWeight<W, K> w1(t2);
185   return Times(w1, w2);
186 }
187
188 template <class W, class K>
189 inline SparsePowerWeight<W, K> Times(const SparsePowerWeight<W, K> &w1,
190                                      const W &k) {
191   const SparseTupleWeight<W, K> t2(k);
192   const SparsePowerWeight<W, K> w2(t2);
193   return Times(w1, w2);
194 }
195
196 template <class W, class K>
197 inline SparsePowerWeight<W, K> Divide(const SparsePowerWeight<W, K> &w1,
198                                       const W &k,
199                                       DivideType divide_type = DIVIDE_ANY) {
200   const SparseTupleWeight<W, K> t2(k);
201   const SparsePowerWeight<W, K> w2(t2);
202   return Divide(w1, w2, divide_type);
203 }
204
205 // This function object generates weights over the Cartesian power of rank
206 // n over the underlying weight. This is intended primarily for testing.
207 template <class W, class K>
208 class WeightGenerate<SparsePowerWeight<W, K>> {
209  public:
210   using Weight = SparsePowerWeight<W, K>;
211   using Generate = WeightGenerate<W>;
212
213   explicit WeightGenerate(bool allow_zero = true,
214                           size_t sparse_power_rank = 3)
215       : generate_(allow_zero), sparse_power_rank_(sparse_power_rank) {}
216
217   Weight operator()() const {
218     Weight weight;
219     for (size_t i = 1; i <= sparse_power_rank_; ++i) {
220       weight.Push(i, generate_(), true);
221     }
222     return weight;
223   }
224
225  private:
226   const Generate generate_;
227   const size_t sparse_power_rank_;
228 };
229
230 }  // namespace fst
231
232 #endif  // FST_LIB_SPARSE_POWER_WEIGHT_H_