d23bd8d4ac680707ac9aad1359e07fc5b60b687b
[platform/upstream/openfst.git] / src / include / fst / sparse-tuple-weight.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Sparse version of tuple-weight, based on tuple-weight.h.
5 // Internally stores sparse key, value pairs in linked list. The default value
6 // element is the assumed value of unset keys. Internal singleton
7 // implementation that stores first key, value pair as a initialized member
8 // variable to avoid unnecessary allocation on heap. Use
9 // SparseTupleWeightIterator to iterate through the key,value pairs. Note:
10 // this does NOT iterate through the default value.
11 //
12 // Sparse tuple weight set operation definitions.
13
14 #ifndef FST_LIB_SPARSE_TUPLE_WEIGHT_H_
15 #define FST_LIB_SPARSE_TUPLE_WEIGHT_H_
16
17 #include <list>
18 #include <stack>
19 #include <string>
20 #include <unordered_map>
21 #include <utility>
22
23
24 #include <fst/weight.h>
25
26
27 namespace fst {
28
29 template <class W, class K>
30 class SparseTupleWeightIterator;
31
32 // Arbitrary dimension tuple weight, stored as a sorted linked-list.
33 // W is any weight class, and K is the key value type. kNoKey (-1) is reserved
34 // for internal use.
35 template <class W, class K = int>
36 class SparseTupleWeight {
37  public:
38   using ReverseWeight = SparseTupleWeight<typename W::ReverseWeight, K>;
39
40   using Pair = std::pair<K, W>;
41
42   constexpr static K kNoKey = -1;
43
44   SparseTupleWeight() { Init(); }
45
46   template <class Iterator>
47   SparseTupleWeight(Iterator begin, Iterator end) {
48     Init();
49     // Assumes input iterator is sorted.
50     for (auto it = begin; it != end; ++it) Push(*it);
51   }
52
53   SparseTupleWeight(const K &key, const W &weight) {
54     Init();
55     Push(key, weight);
56   }
57
58   explicit SparseTupleWeight(const W &weight) { Init(weight); }
59
60   SparseTupleWeight(const SparseTupleWeight &weight) {
61     Init(weight.DefaultValue());
62     SetDefaultValue(weight.DefaultValue());
63     for (SparseTupleWeightIterator<W, K> it(weight); !it.Done(); it.Next()) {
64       Push(it.Value());
65     }
66   }
67
68   static const SparseTupleWeight &Zero() {
69     static const SparseTupleWeight zero(W::Zero());
70     return zero;
71   }
72
73   static const SparseTupleWeight &One() {
74     static const SparseTupleWeight one(W::One());
75     return one;
76   }
77
78   static const SparseTupleWeight &NoWeight() {
79     static const SparseTupleWeight no_weight(W::NoWeight());
80     return no_weight;
81   }
82
83   std::istream &Read(std::istream &strm) {
84     ReadType(strm, &default_);
85     ReadType(strm, &first_);
86     return ReadType(strm, &rest_);
87   }
88
89   std::ostream &Write(std::ostream &strm) const {
90     WriteType(strm, default_);
91     WriteType(strm, first_);
92     return WriteType(strm, rest_);
93   }
94
95   SparseTupleWeight &operator=(const SparseTupleWeight &weight) {
96     if (this == &weight) return *this;  // Checks for identity.
97     Init(weight.DefaultValue());
98     for (SparseTupleWeightIterator<W, K> it(weight); !it.Done(); it.Next()) {
99       Push(it.Value());
100     }
101     return *this;
102   }
103
104   bool Member() const {
105     if (!DefaultValue().Member()) return false;
106     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
107       if (!it.Value().second.Member()) return false;
108     }
109     return true;
110   }
111
112   // Assumes H() function exists for the hash of the key value.
113   size_t Hash() const {
114     size_t h = 0;
115     static const std::hash<K> H;
116     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
117       h = 5 * h + H(it.Value().first);
118       h = 13 * h + it.Value().second.Hash();
119     }
120     return h;
121   }
122
123   SparseTupleWeight Quantize(float delta = kDelta) const {
124     SparseTupleWeight weight;
125     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
126       weight.Push(it.Value().first, it.Value().second.Quantize(delta));
127     }
128     return weight;
129   }
130
131   ReverseWeight Reverse() const {
132     SparseTupleWeight weight;
133     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
134       weight.Push(it.Value().first, it.Value().second.Reverse());
135     }
136     return ReverseWeight(weight);
137   }
138
139   void Init(const W &default_value = W::Zero()) {
140     first_.first = kNoKey;
141     // Initialized to the reserved key value.
142     default_ = default_value;
143     rest_.clear();
144   }
145
146   size_t Size() const {
147     if (first_.first == kNoKey) {
148       return 0;
149     } else {
150       return rest_.size() + 1;
151     }
152   }
153
154   inline void Push(const K &key, const W &weight,
155                    bool default_value_check = true) {
156     Push(std::make_pair(key, weight), default_value_check);
157   }
158
159   inline void Push(const Pair &pair, bool default_value_check = true) {
160     if (default_value_check && pair.second == default_) return;
161     if (first_.first == kNoKey) {
162       first_ = pair;
163     } else {
164       rest_.push_back(pair);
165     }
166   }
167
168   void SetDefaultValue(const W &value) { default_ = value; }
169
170   const W &DefaultValue() const { return default_; }
171
172  private:
173   // Assumed default value of uninitialized keys, by default W::Zero().
174   W default_;
175
176   // Key values pairs are first stored in first_, then fill rest_ this way we
177   // can avoid dynamic allocation in the common case where the weight is a
178   // single key/value pair.
179   Pair first_;
180   std::list<Pair> rest_;
181
182   friend class SparseTupleWeightIterator<W, K>;
183 };
184
185 template <class W, class K>
186 class SparseTupleWeightIterator {
187  public:
188   using Pair = typename SparseTupleWeight<W, K>::Pair;
189   using const_iterator = typename std::list<Pair>::const_iterator;
190   using iterator = typename std::list<Pair>::iterator;
191
192   explicit SparseTupleWeightIterator(const SparseTupleWeight<W, K> &weight)
193       : first_(weight.first_),
194         rest_(weight.rest_),
195         init_(true),
196         iter_(rest_.begin()) {}
197
198   bool Done() const {
199     if (init_) {
200       return first_.first == SparseTupleWeight<W, K>::kNoKey;
201     } else {
202       return iter_ == rest_.end();
203     }
204   }
205
206   const Pair &Value() const { return init_ ? first_ : *iter_; }
207
208   void Next() {
209     if (init_) {
210       init_ = false;
211     } else {
212       ++iter_;
213     }
214   }
215
216   void Reset() {
217     init_ = true;
218     iter_ = rest_.begin();
219   }
220
221  private:
222   const Pair &first_;
223   const std::list<Pair> &rest_;
224   bool init_;  // In the initialized state?
225   const_iterator iter_;
226 };
227
228 template <class W, class K, class M>
229 inline void SparseTupleWeightMap(SparseTupleWeight<W, K> *result,
230                                  const SparseTupleWeight<W, K> &w1,
231                                  const SparseTupleWeight<W, K> &w2,
232                                  const M &operator_mapper) {
233   SparseTupleWeightIterator<W, K> w1_it(w1);
234   SparseTupleWeightIterator<W, K> w2_it(w2);
235   const auto &v1_def = w1.DefaultValue();
236   const auto &v2_def = w2.DefaultValue();
237   result->SetDefaultValue(operator_mapper.Map(0, v1_def, v2_def));
238   while (!w1_it.Done() || !w2_it.Done()) {
239     const auto &k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first;
240     const auto &k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first;
241     const auto &v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second;
242     const auto &v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second;
243     if (k1 == k2) {
244       result->Push(k1, operator_mapper.Map(k1, v1, v2));
245       if (!w1_it.Done()) w1_it.Next();
246       if (!w2_it.Done()) w2_it.Next();
247     } else if (k1 < k2) {
248       result->Push(k1, operator_mapper.Map(k1, v1, v2_def));
249       w1_it.Next();
250     } else {
251       result->Push(k2, operator_mapper.Map(k2, v1_def, v2));
252       w2_it.Next();
253     }
254   }
255 }
256
257 template <class W, class K>
258 inline bool operator==(const SparseTupleWeight<W, K> &w1,
259                        const SparseTupleWeight<W, K> &w2) {
260   const auto &v1_def = w1.DefaultValue();
261   const auto &v2_def = w2.DefaultValue();
262   if (v1_def != v2_def) return false;
263   SparseTupleWeightIterator<W, K> w1_it(w1);
264   SparseTupleWeightIterator<W, K> w2_it(w2);
265   while (!w1_it.Done() || !w2_it.Done()) {
266     const auto &k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first;
267     const auto &k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first;
268     const auto &v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second;
269     const auto &v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second;
270     if (k1 == k2) {
271       if (v1 != v2) return false;
272       if (!w1_it.Done()) w1_it.Next();
273       if (!w2_it.Done()) w2_it.Next();
274     } else if (k1 < k2) {
275       if (v1 != v2_def) return false;
276       w1_it.Next();
277     } else {
278       if (v1_def != v2) return false;
279       w2_it.Next();
280     }
281   }
282   return true;
283 }
284
285 template <class W, class K>
286 inline bool operator!=(const SparseTupleWeight<W, K> &w1,
287                        const SparseTupleWeight<W, K> &w2) {
288   return !(w1 == w2);
289 }
290
291 template <class W, class K>
292 inline std::ostream &operator<<(std::ostream &strm,
293                                 const SparseTupleWeight<W, K> &weight) {
294   CompositeWeightWriter writer(strm);
295   writer.WriteBegin();
296   writer.WriteElement(weight.DefaultValue());
297   for (SparseTupleWeightIterator<W, K> it(weight); !it.Done(); it.Next()) {
298     writer.WriteElement(it.Value().first);
299     writer.WriteElement(it.Value().second);
300   }
301   writer.WriteEnd();
302   return strm;
303 }
304
305 template <class W, class K>
306 inline std::istream &operator>>(std::istream &strm,
307                                 SparseTupleWeight<W, K> &weight) {
308   CompositeWeightReader reader(strm);
309   reader.ReadBegin();
310   W def;
311   bool more = reader.ReadElement(&def);
312   weight.Init(def);
313   while (more) {
314     K key;
315     reader.ReadElement(&key);
316     W v;
317     more = reader.ReadElement(&v);
318     weight.Push(key, v);
319   }
320   reader.ReadEnd();
321   return strm;
322 }
323
324 }  // namespace fst
325
326 #endif  // FST_LIB_SPARSE_TUPLE_WEIGHT_H_