Imported Upstream version 1.6.6
[platform/upstream/openfst.git] / src / lib / symbol-table.cc
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
5
6 #include <fst/symbol-table.h>
7
8 #include <fst/flags.h>
9 #include <fst/log.h>
10
11 #include <fstream>
12 #include <fst/util.h>
13
14 DEFINE_bool(fst_compat_symbols, true,
15             "Require symbol tables to match when appropriate");
16 DEFINE_string(fst_field_separator, "\t ",
17               "Set of characters used as a separator between printed fields");
18
19 namespace fst {
20
21 // Maximum line length in textual symbols file.
22 static constexpr int kLineLen = 8096;
23
24 // Identifies stream data as a symbol table (and its endianity).
25 static constexpr int32 kSymbolTableMagicNumber = 2125658996;
26
27 SymbolTableTextOptions::SymbolTableTextOptions(bool allow_negative_labels)
28     : allow_negative_labels(allow_negative_labels),
29       fst_field_separator(FLAGS_fst_field_separator) {}
30
31 namespace internal {
32
33 SymbolTableImpl *SymbolTableImpl::ReadText(std::istream &strm,
34                                            const string &filename,
35                                            const SymbolTableTextOptions &opts) {
36   std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(filename));
37   int64 nline = 0;
38   char line[kLineLen];
39   while (!strm.getline(line, kLineLen).fail()) {
40     ++nline;
41     std::vector<char *> col;
42     auto separator = opts.fst_field_separator + "\n";
43     SplitString(line, separator.c_str(), &col, true);
44     if (col.empty()) continue;  // Empty line.
45     if (col.size() != 2) {
46       LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
47                  << col.size() << "), "
48                  << "file = " << filename << ", line = " << nline << ":<"
49                  << line << ">";
50       return nullptr;
51     }
52     const char *symbol = col[0];
53     const char *value = col[1];
54     char *p;
55     const auto key = strtoll(value, &p, 10);
56     if (p < value + strlen(value) ||
57        (!opts.allow_negative_labels && key < 0) || key == -1) {
58       LOG(ERROR) << "SymbolTable::ReadText: Bad non-negative integer \""
59                  << value << "\", "
60                  << "file = " << filename << ", line = " << nline;
61       return nullptr;
62     }
63     impl->AddSymbol(symbol, key);
64   }
65   return impl.release();
66 }
67
68 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
69   {
70     ReaderMutexLock check_sum_lock(&check_sum_mutex_);
71     if (check_sum_finalized_) return;
72   }
73   // We'll acquire an exclusive lock to recompute the checksums.
74   MutexLock check_sum_lock(&check_sum_mutex_);
75   if (check_sum_finalized_) {  // Another thread (coming in around the same time
76     return;                    // might have done it already). So we recheck.
77   }
78   // Calculates the original label-agnostic checksum.
79   CheckSummer check_sum;
80   for (size_t i = 0; i < symbols_.size(); ++i) {
81     const auto &symbol = symbols_.GetSymbol(i);
82     check_sum.Update(symbol.data(), symbol.size());
83     check_sum.Update("", 1);
84   }
85   check_sum_string_ = check_sum.Digest();
86   // Calculates the safer, label-dependent checksum.
87   CheckSummer labeled_check_sum;
88   for (int64 i = 0; i < dense_key_limit_; ++i) {
89     std::ostringstream line;
90     line << symbols_.GetSymbol(i) << '\t' << i;
91     labeled_check_sum.Update(line.str().data(), line.str().size());
92   }
93   using citer = map<int64, int64>::const_iterator;
94   for (citer it = key_map_.begin(); it != key_map_.end(); ++it) {
95     // TODO(tombagby, 2013-11-22) This line maintains a bug that ignores
96     // negative labels in the checksum that too many tests rely on.
97     if (it->first < dense_key_limit_) continue;
98     std::ostringstream line;
99     line << symbols_.GetSymbol(it->second) << '\t' << it->first;
100     labeled_check_sum.Update(line.str().data(), line.str().size());
101   }
102   labeled_check_sum_string_ = labeled_check_sum.Digest();
103   check_sum_finalized_ = true;
104 }
105
106 int64 SymbolTableImpl::AddSymbol(const string &symbol, int64 key) {
107   if (key == -1) return key;
108   const std::pair<int64, bool> &insert_key = symbols_.InsertOrFind(symbol);
109   if (!insert_key.second) {
110     auto key_already = GetNthKey(insert_key.first);
111     if (key_already == key) return key;
112     VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
113             << " already in symbol_map_ with key = " << key_already
114             << " but supplied new key = " << key << " (ignoring new key)";
115     return key_already;
116   }
117   if (key == (symbols_.size() - 1) && key == dense_key_limit_) {
118     ++dense_key_limit_;
119   } else {
120     idx_key_.push_back(key);
121     key_map_[key] = symbols_.size() - 1;
122   }
123   if (key >= available_key_) available_key_ = key + 1;
124   check_sum_finalized_ = false;
125   return key;
126 }
127
128 // TODO(rybach): Consider a more efficient implementation which re-uses holes in
129 // the dense-key range or re-arranges the dense-key range from time to time.
130 void SymbolTableImpl::RemoveSymbol(const int64 key) {
131   auto idx = key;
132   if (key < 0 || key >= dense_key_limit_) {
133     auto iter = key_map_.find(key);
134     if (iter == key_map_.end()) return;
135     idx = iter->second;
136     key_map_.erase(iter);
137   }
138   if (idx < 0 || idx >= symbols_.size()) return;
139   symbols_.RemoveSymbol(idx);
140   // Removed one symbol, all indexes > idx are shifted by -1.
141   for (auto &k : key_map_) {
142     if (k.second > idx) --k.second;
143   }
144   if (key >= 0 && key < dense_key_limit_) {
145     // Removal puts a hole in the dense key range. Adjusts range to [0, key).
146     const auto new_dense_key_limit = key;
147     for (int64 i = key + 1; i < dense_key_limit_; ++i) {
148       key_map_[i] = i - 1;
149     }
150     // Moves existing values in idx_key to new place.
151     idx_key_.resize(symbols_.size() - new_dense_key_limit);
152     for (int64 i = symbols_.size(); i >= dense_key_limit_; --i) {
153       idx_key_[i - new_dense_key_limit - 1] = idx_key_[i - dense_key_limit_];
154     }
155     // Adds indexes for previously dense keys.
156     for (int64 i = new_dense_key_limit; i < dense_key_limit_ - 1; ++i) {
157       idx_key_[i - new_dense_key_limit] = i + 1;
158     }
159     dense_key_limit_ = new_dense_key_limit;
160   } else {
161     // Remove entry for removed index in idx_key.
162     for (int64 i = idx - dense_key_limit_; i < idx_key_.size() - 1; ++i) {
163       idx_key_[i] = idx_key_[i + 1];
164     }
165     idx_key_.pop_back();
166   }
167   if (key == available_key_ - 1) available_key_ = key;
168 }
169
170 SymbolTableImpl *SymbolTableImpl::Read(std::istream &strm,
171                                        const SymbolTableReadOptions &opts) {
172   int32 magic_number = 0;
173   ReadType(strm, &magic_number);
174   if (strm.fail()) {
175     LOG(ERROR) << "SymbolTable::Read: Read failed";
176     return nullptr;
177   }
178   string name;
179   ReadType(strm, &name);
180   std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(name));
181   ReadType(strm, &impl->available_key_);
182   int64 size;
183   ReadType(strm, &size);
184   if (strm.fail()) {
185     LOG(ERROR) << "SymbolTable::Read: Read failed";
186     return nullptr;
187   }
188   string symbol;
189   int64 key;
190   impl->check_sum_finalized_ = false;
191   for (int64 i = 0; i < size; ++i) {
192     ReadType(strm, &symbol);
193     ReadType(strm, &key);
194     if (strm.fail()) {
195       LOG(ERROR) << "SymbolTable::Read: Read failed";
196       return nullptr;
197     }
198     impl->AddSymbol(symbol, key);
199   }
200   return impl.release();
201 }
202
203 bool SymbolTableImpl::Write(std::ostream &strm) const {
204   WriteType(strm, kSymbolTableMagicNumber);
205   WriteType(strm, name_);
206   WriteType(strm, available_key_);
207   int64 size = symbols_.size();
208   WriteType(strm, size);
209   for (int64 i = 0; i < size; ++i) {
210     auto key = (i < dense_key_limit_) ? i : idx_key_[i - dense_key_limit_];
211     WriteType(strm, symbols_.GetSymbol(i));
212     WriteType(strm, key);
213   }
214   strm.flush();
215   if (strm.fail()) {
216     LOG(ERROR) << "SymbolTable::Write: Write failed";
217     return false;
218   }
219   return true;
220 }
221
222 }  // namespace internal
223
224 constexpr int64 SymbolTable::kNoSymbol;
225
226 void SymbolTable::AddTable(const SymbolTable &table) {
227   MutateCheck();
228   for (SymbolTableIterator iter(table); !iter.Done(); iter.Next()) {
229     impl_->AddSymbol(iter.Symbol());
230   }
231 }
232
233 bool SymbolTable::WriteText(std::ostream &strm,
234                             const SymbolTableTextOptions &opts) const {
235   if (opts.fst_field_separator.empty()) {
236     LOG(ERROR) << "Missing required field separator";
237     return false;
238   }
239   bool once_only = false;
240   for (SymbolTableIterator iter(*this); !iter.Done(); iter.Next()) {
241     std::ostringstream line;
242     if (iter.Value() < 0 && !opts.allow_negative_labels && !once_only) {
243       LOG(WARNING) << "Negative symbol table entry when not allowed";
244       once_only = true;
245     }
246     line << iter.Symbol() << opts.fst_field_separator[0] << iter.Value()
247          << '\n';
248     strm.write(line.str().data(), line.str().length());
249   }
250   return true;
251 }
252
253 namespace internal {
254
255 DenseSymbolMap::DenseSymbolMap()
256     : empty_(-1), buckets_(1 << 4), hash_mask_(buckets_.size() - 1) {
257   std::uninitialized_fill(buckets_.begin(), buckets_.end(), empty_);
258 }
259
260 DenseSymbolMap::DenseSymbolMap(const DenseSymbolMap &x)
261     : empty_(-1),
262       symbols_(x.symbols_.size()),
263       buckets_(x.buckets_),
264       hash_mask_(x.hash_mask_) {
265   for (size_t i = 0; i < symbols_.size(); ++i) {
266     const auto sz = strlen(x.symbols_[i]) + 1;
267     auto *cpy = new char[sz];
268     memcpy(cpy, x.symbols_[i], sz);
269     symbols_[i] = cpy;
270   }
271 }
272
273 DenseSymbolMap::~DenseSymbolMap() {
274   for (size_t i = 0; i < symbols_.size(); ++i) delete[] symbols_[i];
275 }
276
277 std::pair<int64, bool> DenseSymbolMap::InsertOrFind(const string &key) {
278   static constexpr float kMaxOccupancyRatio = 0.75;  // Grows when 75% full.
279   if (symbols_.size() >= kMaxOccupancyRatio * buckets_.size()) {
280     Rehash(buckets_.size() * 2);
281   }
282   size_t idx = str_hash_(key) & hash_mask_;
283   while (buckets_[idx] != empty_) {
284     const auto stored_value = buckets_[idx];
285     if (!strcmp(symbols_[stored_value], key.c_str())) {
286       return std::make_pair(stored_value, false);
287     }
288     idx = (idx + 1) & hash_mask_;
289   }
290   auto next = symbols_.size();
291   buckets_[idx] = next;
292   symbols_.push_back(NewSymbol(key));
293   return std::make_pair(next, true);
294 }
295
296 int64 DenseSymbolMap::Find(const string &key) const {
297   size_t idx = str_hash_(key) & hash_mask_;
298   while (buckets_[idx] != empty_) {
299     const auto stored_value = buckets_[idx];
300     if (!strcmp(symbols_[stored_value], key.c_str())) {
301       return stored_value;
302     }
303     idx = (idx + 1) & hash_mask_;
304   }
305   return buckets_[idx];
306 }
307
308 void DenseSymbolMap::Rehash(size_t num_buckets) {
309   buckets_.resize(num_buckets);
310   hash_mask_ = buckets_.size() - 1;
311   std::uninitialized_fill(buckets_.begin(), buckets_.end(), empty_);
312   for (size_t i = 0; i < symbols_.size(); ++i) {
313     size_t idx = str_hash_(string(symbols_[i])) & hash_mask_;
314     while (buckets_[idx] != empty_) {
315       idx = (idx + 1) & hash_mask_;
316     }
317     buckets_[idx] = i;
318   }
319 }
320
321 const char *DenseSymbolMap::NewSymbol(const string &sym) {
322   auto num = sym.size() + 1;
323   auto newstr = new char[num];
324   memcpy(newstr, sym.c_str(), num);
325   return newstr;
326 }
327
328 void DenseSymbolMap::RemoveSymbol(size_t idx) {
329   delete[] symbols_[idx];
330   symbols_.erase(symbols_.begin() + idx);
331   Rehash(buckets_.size());
332 }
333
334 }  // namespace internal
335
336 bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
337                    bool warning) {
338   // Flag can explicitly override this check.
339   if (!FLAGS_fst_compat_symbols) return true;
340   if (syms1 && syms2 &&
341       (syms1->LabeledCheckSum() != syms2->LabeledCheckSum())) {
342     if (warning) {
343       LOG(WARNING) << "CompatSymbols: Symbol table checksums do not match. "
344                    << "Table sizes are " << syms1->NumSymbols() << " and "
345                    << syms2->NumSymbols();
346     }
347     return false;
348   } else {
349     return true;
350   }
351 }
352
353 void SymbolTableToString(const SymbolTable *table, string *result) {
354   std::ostringstream ostrm;
355   table->Write(ostrm);
356   *result = ostrm.str();
357 }
358
359 SymbolTable *StringToSymbolTable(const string &str) {
360   std::istringstream istrm(str);
361   return SymbolTable::Read(istrm, SymbolTableReadOptions());
362 }
363 }  // namespace fst