1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
6 #include <fst/symbol-table.h>
// Flag --fst_compat_symbols: when set to false, CompatSymbols() returns true
// immediately instead of comparing the two tables' labeled checksums.
14 DEFINE_bool(fst_compat_symbols, true,
15 "Require symbol tables to match when appropriate");
// Flag --fst_field_separator: default value copied into
// SymbolTableTextOptions::fst_field_separator. Text reading passes these
// characters (plus '\n') to SplitString() as the delimiter set. Text writing
// emits only the first character between a symbol and its key.
16 DEFINE_string(fst_field_separator, "\t ",
17 "Set of characters used as a separator between printed fields");
// Capacity of the per-line buffer handed to std::istream::getline() while
// reading a textual symbol table. A line fits only if it has at most
// kLineLen - 1 characters, since one slot is reserved for the terminating NUL.
static constexpr int kLineLen{8096};
24 // Identifies stream data as a symbol table (and its endianity).
25 static constexpr int32 kSymbolTableMagicNumber = 2125658996;
27 SymbolTableTextOptions::SymbolTableTextOptions(bool allow_negative_labels)
28 : allow_negative_labels(allow_negative_labels),
29 fst_field_separator(FLAGS_fst_field_separator) {}
// Builds a symbol table named `filename` from the text in `strm`.
//
// Each line is split with SplitString(). The delimiter set is the characters
// of `opts.fst_field_separator` plus '\n'. Lines that yield no fields are
// skipped. Every other line must yield exactly two fields, `symbol` and
// `key`, and is then added with AddSymbol(symbol, key).
//
// A line is rejected with LOG(ERROR), which reports `filename` and the line
// number `nline`, when any of the following holds:
//   * it does not have exactly two fields;
//   * `key` is not entirely a base-10 integer (e.g. "12a");
//   * `key` is -1;
//   * `key` is negative and `opts.allow_negative_labels` is false.
// NOTE(review): the statements following each LOG(ERROR) are not visible in
// this view. Presumably they return nullptr — confirm.
//
// On success, ownership of the new impl passes to the caller (impl.release()).
33 SymbolTableImpl *SymbolTableImpl::ReadText(std::istream &strm,
34 const string &filename,
35 const SymbolTableTextOptions &opts) {
36 std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(filename));
// NOTE(review): std::istream::getline() sets failbit when a line does not fit
// in kLineLen - 1 characters. That ends this loop exactly like end-of-file,
// so an over-long line silently truncates the table without any error.
39 while (!strm.getline(line, kLineLen).fail()) {
41 std::vector<char *> col;
42 auto separator = opts.fst_field_separator + "\n";
43 SplitString(line, separator.c_str(), &col, true);
44 if (col.empty()) continue; // Empty line.
45 if (col.size() != 2) {
46 LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns ("
47 << col.size() << "), "
48 << "file = " << filename << ", line = " << nline << ":<"
52 const char *symbol = col[0];
53 const char *value = col[1];
// strtoll() leaves `p` at the first character it did not consume. If `p`
// stops short of the end of `value`, the field has trailing garbage.
// -1 is refused even when negative labels are allowed, because AddSymbol()
// treats key -1 as a no-op.
// NOTE(review): out-of-range values saturate to LLONG_MIN/LLONG_MAX (errno is
// set to ERANGE), and errno is not checked here.
55 const auto key = strtoll(value, &p, 10);
56 if (p < value + strlen(value) ||
57 (!opts.allow_negative_labels && key < 0) || key == -1) {
58 LOG(ERROR) << "SymbolTable::ReadText: Bad non-negative integer \""
60 << "file = " << filename << ", line = " << nline;
63 impl->AddSymbol(symbol, key);
65 return impl.release();
// Recomputes the two cached checksums when they are stale
// (check_sum_finalized_ is false), then marks them finalized.
//
// check_sum_string_ covers only the symbol strings.
// labeled_check_sum_string_ also covers each symbol's key, so it changes when
// labels are remapped.
//
// The flag is first tested under a shared ReaderMutexLock. Only if it is
// clear is an exclusive MutexLock taken. The flag is then tested again,
// because another thread may have finished the work in between.
// NOTE(review): both guards are named `check_sum_lock`. The reader guard must
// therefore live in its own nested scope and be released before the
// exclusive lock is requested. Those braces are not visible in this view —
// confirm they are present.
68 void SymbolTableImpl::MaybeRecomputeCheckSum() const {
70 ReaderMutexLock check_sum_lock(&check_sum_mutex_);
71 if (check_sum_finalized_) return;
73 // We'll acquire an exclusive lock to recompute the checksums.
74 MutexLock check_sum_lock(&check_sum_mutex_);
75 if (check_sum_finalized_) { // Another thread (coming in around the same time
76 return; // might have done it already). So we recheck.
78 // Calculates the original label-agnostic checksum.
// Each symbol, in index order, is followed by one NUL byte ("" with length 1).
// This way {"ab", "c"} and {"a", "bc"} feed different byte streams.
79 CheckSummer check_sum;
80 for (size_t i = 0; i < symbols_.size(); ++i) {
81 const auto &symbol = symbols_.GetSymbol(i);
82 check_sum.Update(symbol.data(), symbol.size());
83 check_sum.Update("", 1);
85 check_sum_string_ = check_sum.Digest();
86 // Calculates the safer, label-dependent checksum.
// One "symbol\tkey" record per entry. The dense keys
// 0 .. dense_key_limit_ - 1 come first; their key equals the symbol index.
// The key_map_ entries follow, in key_map_ iteration order.
87 CheckSummer labeled_check_sum;
88 for (int64 i = 0; i < dense_key_limit_; ++i) {
89 std::ostringstream line;
90 line << symbols_.GetSymbol(i) << '\t' << i;
91 labeled_check_sum.Update(line.str().data(), line.str().size());
93 using citer = map<int64, int64>::const_iterator;
94 for (citer it = key_map_.begin(); it != key_map_.end(); ++it) {
95 // TODO(tombagby, 2013-11-22) This line maintains a bug that ignores
96 // negative labels in the checksum that too many tests rely on.
97 if (it->first < dense_key_limit_) continue;
98 std::ostringstream line;
99 line << symbols_.GetSymbol(it->second) << '\t' << it->first;
100 labeled_check_sum.Update(line.str().data(), line.str().size());
102 labeled_check_sum_string_ = labeled_check_sum.Digest();
103 check_sum_finalized_ = true;
// Associates `symbol` with label `key` and returns a key for it.
//  * key == -1 is a no-op that returns -1 without touching the table.
//  * If `symbol` is already stored (InsertOrFind() reports no insertion)
//    under the same key, `key` is returned unchanged. If it is stored under a
//    different key, the conflict is logged at VLOG(1). The message states
//    that the new key is ignored; the statements that follow the log are not
//    visible in this view.
//  * A newly inserted symbol whose key equals both its index and
//    dense_key_limit_ extends the contiguous dense key range. Otherwise the
//    key is recorded in idx_key_ and key_map_ as a sparse key. (The dense
//    branch body and the else/braces are not visible here — confirm.)
// Finally, available_key_ is raised to one past `key` if needed, and the
// cached checksums are marked stale.
106 int64 SymbolTableImpl::AddSymbol(const string &symbol, int64 key) {
107 if (key == -1) return key;
108 const std::pair<int64, bool> &insert_key = symbols_.InsertOrFind(symbol);
109 if (!insert_key.second) {
110 auto key_already = GetNthKey(insert_key.first);
111 if (key_already == key) return key;
112 VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol
113 << " already in symbol_map_ with key = " << key_already
114 << " but supplied new key = " << key << " (ignoring new key)";
// The new symbol was appended, so its index is symbols_.size() - 1.
117 if (key == (symbols_.size() - 1) && key == dense_key_limit_) {
120 idx_key_.push_back(key);
121 key_map_[key] = symbols_.size() - 1;
123 if (key >= available_key_) available_key_ = key + 1;
124 check_sum_finalized_ = false;
// Removes the symbol whose label is `key`.
//
// A key outside the dense range [0, dense_key_limit_) is looked up in
// key_map_. If it is absent, nothing happens; otherwise its key_map_ entry is
// erased. A resolved index outside [0, symbols_.size()) also returns early.
// After the symbol is removed from symbols_, every symbol after it shifts
// down one slot, so every stored index greater than the removed one is
// decremented. Removing a dense key truncates the dense range to [0, key).
// The keys that were above it are then recorded explicitly in idx_key_. If
// `key` was the highest allocated key (available_key_ - 1), available_key_ is
// lowered to `key`.
128 // TODO(rybach): Consider a more efficient implementation which re-uses holes in
129 // the dense-key range or re-arranges the dense-key range from time to time.
130 void SymbolTableImpl::RemoveSymbol(const int64 key) {
132 if (key < 0 || key >= dense_key_limit_) {
133 auto iter = key_map_.find(key);
134 if (iter == key_map_.end()) return;
136 key_map_.erase(iter);
138 if (idx < 0 || idx >= symbols_.size()) return;
139 symbols_.RemoveSymbol(idx);
140 // Removed one symbol, all indexes > idx are shifted by -1.
141 for (auto &k : key_map_) {
142 if (k.second > idx) --k.second;
144 if (key >= 0 && key < dense_key_limit_) {
145 // Removal puts a hole in the dense key range. Adjusts range to [0, key).
146 const auto new_dense_key_limit = key;
147 for (int64 i = key + 1; i < dense_key_limit_; ++i) {
150 // Moves existing values in idx_key to new place.
// `i` runs over the pre-removal indices of the sparse symbols
// (dense_key_limit_ .. old size - 1); symbols_.size() is already one smaller.
// The destination slot is never below the source slot, because
// dense_key_limit_ - new_dense_key_limit - 1 >= 0. Walking downward therefore
// never overwrites an entry that has not been moved yet.
151 idx_key_.resize(symbols_.size() - new_dense_key_limit);
152 for (int64 i = symbols_.size(); i >= dense_key_limit_; --i) {
153 idx_key_[i - new_dense_key_limit - 1] = idx_key_[i - dense_key_limit_];
155 // Adds indexes for previously dense keys.
// Former dense keys key+1 .. dense_key_limit_-1 now sit at symbol indices
// key .. dense_key_limit_-2, so index i maps to key i + 1.
156 for (int64 i = new_dense_key_limit; i < dense_key_limit_ - 1; ++i) {
157 idx_key_[i - new_dense_key_limit] = i + 1;
159 dense_key_limit_ = new_dense_key_limit;
161 // Remove entry for removed index in idx_key.
// NOTE(review): idx_key_.size() - 1 is unsigned and would wrap if idx_key_
// were empty. Presumably it cannot be empty here, because the removed key was
// sparse — confirm.
162 for (int64 i = idx - dense_key_limit_; i < idx_key_.size() - 1; ++i) {
163 idx_key_[i] = idx_key_[i + 1];
167 if (key == available_key_ - 1) available_key_ = key;
// Deserializes a table in the binary layout emitted by Write(), in order:
//   1. int32 magic number;
//   2. table name;
//   3. available_key_;
//   4. symbol count;
//   5. (symbol, key) pairs.
// Each pair is re-added with AddSymbol(), which rebuilds the dense and sparse
// key structures. Read errors are logged as "SymbolTable::Read: Read failed".
// The conditions guarding those logs, the magic-number comparison and the
// early returns are not visible in this view.
// NOTE(review): `opts` is not referenced in the visible lines.
// On success, ownership of the new impl passes to the caller.
170 SymbolTableImpl *SymbolTableImpl::Read(std::istream &strm,
171 const SymbolTableReadOptions &opts) {
172 int32 magic_number = 0;
173 ReadType(strm, &magic_number);
175 LOG(ERROR) << "SymbolTable::Read: Read failed";
179 ReadType(strm, &name);
180 std::unique_ptr<SymbolTableImpl> impl(new SymbolTableImpl(name));
181 ReadType(strm, &impl->available_key_);
183 ReadType(strm, &size);
185 LOG(ERROR) << "SymbolTable::Read: Read failed";
190 impl->check_sum_finalized_ = false;
191 for (int64 i = 0; i < size; ++i) {
192 ReadType(strm, &symbol);
193 ReadType(strm, &key);
195 LOG(ERROR) << "SymbolTable::Read: Read failed";
198 impl->AddSymbol(symbol, key);
200 return impl.release();
// Serializes the table in the binary layout that Read() consumes, in order:
//   1. kSymbolTableMagicNumber;
//   2. name_;
//   3. available_key_;
//   4. the symbol count;
//   5. one (symbol, key) pair per symbol, in index order.
// A failure is logged as "SymbolTable::Write: Write failed". The stream check
// and the return statements are not visible in this view.
203 bool SymbolTableImpl::Write(std::ostream &strm) const {
204 WriteType(strm, kSymbolTableMagicNumber);
205 WriteType(strm, name_);
206 WriteType(strm, available_key_);
207 int64 size = symbols_.size();
208 WriteType(strm, size);
209 for (int64 i = 0; i < size; ++i) {
// Dense symbols use their index as key. Later symbols look their key up in
// idx_key_, which is indexed from dense_key_limit_.
210 auto key = (i < dense_key_limit_) ? i : idx_key_[i - dense_key_limit_];
211 WriteType(strm, symbols_.GetSymbol(i));
212 WriteType(strm, key);
216 LOG(ERROR) << "SymbolTable::Write: Write failed";
222 } // namespace internal
// Out-of-class definition of the static constexpr data member kNoSymbol.
// Before C++17 this definition is required whenever the member is odr-used
// (e.g. bound to a const reference). Since C++17, static constexpr data
// members are implicitly inline, so this line is redundant but harmless.
224 constexpr int64 SymbolTable::kNoSymbol;
// Adds every symbol of `table` to this table. Only the symbol strings are
// taken from `table`; its keys (iter.Value()) are not passed on. Keys are
// assigned by the one-argument AddSymbol() overload, which is not visible
// here (presumably this table's next available key — confirm).
226 void SymbolTable::AddTable(const SymbolTable &table) {
228 for (SymbolTableIterator iter(table); !iter.Done(); iter.Next()) {
229 impl_->AddSymbol(iter.Symbol());
// Writes the table as text, one entry per line: the symbol, then the first
// character of `opts.fst_field_separator`, then the key.
// An empty separator is reported with LOG(ERROR).
// A negative key with `opts.allow_negative_labels` false triggers a warning.
// `once_only` is meant to limit that warning to one; the line that sets it is
// not visible here. In the visible lines such entries are still written.
// The remainder of each output line and the return values are not visible
// in this view.
233 bool SymbolTable::WriteText(std::ostream &strm,
234 const SymbolTableTextOptions &opts) const {
235 if (opts.fst_field_separator.empty()) {
236 LOG(ERROR) << "Missing required field separator";
239 bool once_only = false;
240 for (SymbolTableIterator iter(*this); !iter.Done(); iter.Next()) {
241 std::ostringstream line;
242 if (iter.Value() < 0 && !opts.allow_negative_labels && !once_only) {
243 LOG(WARNING) << "Negative symbol table entry when not allowed";
246 line << iter.Symbol() << opts.fst_field_separator[0] << iter.Value()
248 strm.write(line.str().data(), line.str().length());
// Creates an empty map with 16 buckets, each holding the sentinel empty_
// (-1). The bucket count is a power of two, so hash_mask_ (15) can replace a
// modulo when reducing hashes to bucket indices.
// NOTE(review): hash_mask_ is initialized from buckets_.size(). This relies
// on buckets_ being declared before hash_mask_ in the class, which is not
// visible here — confirm.
// NOTE(review): buckets_ elements are already constructed at this point, so
// std::fill would state the intent more plainly than std::uninitialized_fill.
255 DenseSymbolMap::DenseSymbolMap()
256 : empty_(-1), buckets_(1 << 4), hash_mask_(buckets_.size() - 1) {
257 std::uninitialized_fill(buckets_.begin(), buckets_.end(), empty_);
// Deep copy of `x`. buckets_ and hash_mask_ are copied unchanged; the stored
// indices stay valid as long as symbol order is preserved. Each symbol
// string of `x`, including its terminating NUL, is duplicated into a new
// new[] buffer, so the copy owns its strings independently of `x`.
// (The storing of `cpy` into symbols_[i] and the empty_ initializer are not
// visible in this view — presumably present; confirm.)
260 DenseSymbolMap::DenseSymbolMap(const DenseSymbolMap &x)
262 symbols_(x.symbols_.size()),
263 buckets_(x.buckets_),
264 hash_mask_(x.hash_mask_) {
265 for (size_t i = 0; i < symbols_.size(); ++i) {
266 const auto sz = strlen(x.symbols_[i]) + 1;
267 auto *cpy = new char[sz];
268 memcpy(cpy, x.symbols_[i], sz);
// Releases every owned symbol string. The strings were allocated with new[]
// by NewSymbol() and by the copy constructor.
// NOTE(review): because this class owns raw buffers, copy assignment must
// either deep-copy or be deleted; otherwise two maps would delete[] the same
// strings. Its declaration is not visible here — confirm.
273 DenseSymbolMap::~DenseSymbolMap() {
274 for (size_t i = 0; i < symbols_.size(); ++i) delete[] symbols_[i];
// Returns {index, false} if `key` is already stored. Otherwise it appends a
// copy of `key` and returns {new index, true}.
// The bucket array is doubled once the symbol count reaches 75% of the bucket
// count. Collisions are resolved by probing successive buckets
// ((idx + 1) & hash_mask_).
// NOTE(review): strcmp() stops at the first NUL, so keys that differ only
// after an embedded '\0' compare equal.
277 std::pair<int64, bool> DenseSymbolMap::InsertOrFind(const string &key) {
278 static constexpr float kMaxOccupancyRatio = 0.75; // Grows when 75% full.
// Growing before probing keeps the table from ever filling. The probe loop
// below therefore always reaches an empty bucket.
279 if (symbols_.size() >= kMaxOccupancyRatio * buckets_.size()) {
280 Rehash(buckets_.size() * 2);
282 size_t idx = str_hash_(key) & hash_mask_;
283 while (buckets_[idx] != empty_) {
284 const auto stored_value = buckets_[idx];
285 if (!strcmp(symbols_[stored_value], key.c_str())) {
286 return std::make_pair(stored_value, false);
288 idx = (idx + 1) & hash_mask_;
// Not found: the new symbol takes the next index and this empty bucket.
290 auto next = symbols_.size();
291 buckets_[idx] = next;
292 symbols_.push_back(NewSymbol(key));
293 return std::make_pair(next, true);
// Looks up `key` using the same probe sequence as InsertOrFind().
// A matching entry's index is returned; that return statement is not
// visible in this view. If the probe reaches an empty bucket, the function
// returns buckets_[idx], which at that point equals the sentinel empty_
// (i.e. "not found").
296 int64 DenseSymbolMap::Find(const string &key) const {
297 size_t idx = str_hash_(key) & hash_mask_;
298 while (buckets_[idx] != empty_) {
299 const auto stored_value = buckets_[idx];
300 if (!strcmp(symbols_[stored_value], key.c_str())) {
303 idx = (idx + 1) & hash_mask_;
305 return buckets_[idx];
// Rebuilds the bucket array with `num_buckets` buckets, then re-probes every
// stored symbol into it. `num_buckets` must be a power of two, because
// buckets_.size() - 1 is used as a bit mask in place of a modulo.
// (The store into the found bucket is not visible in this view.)
// NOTE(review): string(symbols_[i]) allocates a temporary per symbol just to
// hash it.
308 void DenseSymbolMap::Rehash(size_t num_buckets) {
309 buckets_.resize(num_buckets);
310 hash_mask_ = buckets_.size() - 1;
311 std::uninitialized_fill(buckets_.begin(), buckets_.end(), empty_);
312 for (size_t i = 0; i < symbols_.size(); ++i) {
313 size_t idx = str_hash_(string(symbols_[i])) & hash_mask_;
314 while (buckets_[idx] != empty_) {
315 idx = (idx + 1) & hash_mask_;
// Allocates a NUL-terminated copy of `sym` with new[]. Copying size() + 1
// bytes from c_str() includes the terminator. The buffer is released with
// delete[] by the destructor and by RemoveSymbol().
321 const char *DenseSymbolMap::NewSymbol(const string &sym) {
322 auto num = sym.size() + 1;
323 auto newstr = new char[num];
324 memcpy(newstr, sym.c_str(), num);
// Frees and erases the symbol at `idx`. `idx` is not range-checked here;
// SymbolTableImpl::RemoveSymbol validates it before calling.
// Erasing shifts every later symbol down one index, which invalidates the
// indices stored in buckets_. The buckets are therefore rebuilt with a
// same-size Rehash(), making each removal linear in the table size.
328 void DenseSymbolMap::RemoveSymbol(size_t idx) {
329 delete[] symbols_[idx];
330 symbols_.erase(symbols_.begin() + idx);
331 Rehash(buckets_.size());
334 } // namespace internal
// Reports whether two symbol tables (either may be null) are compatible.
// If --fst_compat_symbols is false, the check is skipped and true is
// returned. When both tables are non-null, they are compared by
// LabeledCheckSum(), which covers symbols and their keys. A mismatch logs a
// warning that includes both table sizes.
// The remaining parameter line and the rest of the body, including the
// other return paths, are not visible in this view.
336 bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
338 // Flag can explicitly override this check.
339 if (!FLAGS_fst_compat_symbols) return true;
340 if (syms1 && syms2 &&
341 (syms1->LabeledCheckSum() != syms2->LabeledCheckSum())) {
343 LOG(WARNING) << "CompatSymbols: Symbol table checksums do not match. "
344 << "Table sizes are " << syms1->NumSymbols() << " and "
345 << syms2->NumSymbols();
// Serializes `table` into `*result` through an in-memory stream.
// The call that writes `table` into `ostrm` is not visible in this view.
// It is presumably the binary Write() format, so that StringToSymbolTable()
// can read it back — confirm.
353 void SymbolTableToString(const SymbolTable *table, string *result) {
354 std::ostringstream ostrm;
356 *result = ostrm.str();
// Reconstructs a symbol table from `str` with SymbolTable::Read() and
// default SymbolTableReadOptions. Ownership of the returned pointer
// presumably passes to the caller — confirm against SymbolTable::Read().
359 SymbolTable *StringToSymbolTable(const string &str) {
360 std::istringstream istrm(str);
361 return SymbolTable::Read(istrm, SymbolTableReadOptions());