1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Functions and classes for the recursive replacement of FSTs.
11 #include <unordered_map>
17 #include <fst/cache.h>
18 #include <fst/expanded-fst.h>
19 #include <fst/fst-decl.h> // For optional argument declarations.
21 #include <fst/matcher.h>
22 #include <fst/replace-util.h>
23 #include <fst/state-table.h>
24 #include <fst/test-properties.h>
28 // Replace state tables have the form:
30 // template <class Arc, class P>
31 // class ReplaceStateTable {
33 // using Label = typename Arc::Label;
34 // using StateId = typename Arc::StateId;
36 // using PrefixId = P;
37 // using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
38 // using StackPrefix = ReplaceStackPrefix<Label, StateId>;
40 // // Required constructor.
42 // const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
45 // // Required copy constructor that does not copy state.
46 // ReplaceStateTable(const ReplaceStateTable<Arc, PrefixId> &table);
48 // // Looks up state ID by tuple, adding it if it doesn't exist.
49 // StateId FindState(const StateTuple &tuple);
51 // // Looks up state tuple by ID.
52 // const StateTuple &Tuple(StateId id) const;
54 // Looks up prefix ID by stack prefix, adding it if it doesn't exist.
55 // PrefixId FindPrefixId(const StackPrefix &stack_prefix);
57 // // Looks up stack prefix by ID.
58 // const StackPrefix &GetStackPrefix(PrefixId id) const;
61 // Tuple that uniquely defines a state in replace.
//
// S is the state ID type and P the prefix ID type (see how the tuple is
// instantiated as ReplaceStateTuple<StateId, PrefixId> elsewhere in this
// file). A tuple names a position in the expansion by combining a prefix-table
// index, the component FST being walked and the state inside that FST.
62 template <class S, class P>
63 struct ReplaceStateTuple {
// Every argument is optional: prefix_id defaults to -1 and both state fields
// default to kNoStateId.
67 ReplaceStateTuple(PrefixId prefix_id = -1, StateId fst_id = kNoStateId,
68 StateId fst_state = kNoStateId)
69 : prefix_id(prefix_id), fst_id(fst_id), fst_state(fst_state) {}
71 PrefixId prefix_id; // Index in prefix table.
72 StateId fst_id; // Current FST being walked.
73 StateId fst_state; // Current state in FST being walked (not to be
74 // confused with the StateId of the combined FST).
77 // Equality of replace state tuples.
78 template <class StateId, class PrefixId>
79 inline bool operator==(const ReplaceStateTuple<StateId, PrefixId> &x,
80 const ReplaceStateTuple<StateId, PrefixId> &y) {
81 return x.prefix_id == y.prefix_id && x.fst_id == y.fst_id &&
82 x.fst_state == y.fst_state;
85 // Functor returning true for tuples corresponding to states in the root FST.
86 template <class StateId, class PrefixId>
87 class ReplaceRootSelector {
89 bool operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
90 return tuple.prefix_id == 0;
94 // Functor for fingerprinting replace state tuples.
95 template <class StateId, class PrefixId>
96 class ReplaceFingerprint {
98 explicit ReplaceFingerprint(const std::vector<uint64> *size_array)
99 : size_array_(size_array) {}
101 uint64 operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
102 return tuple.prefix_id * size_array_->back() +
103 size_array_->at(tuple.fst_id - 1) + tuple.fst_state;
107 const std::vector<uint64> *size_array_;
110 // Useful when the fst_state uniquely define the tuple.
111 template <class StateId, class PrefixId>
112 class ReplaceFstStateFingerprint {
114 uint64 operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
115 return tuple.fst_state;
119 // A generic hash function for replace state tuples.
120 template <typename S, typename P>
123 size_t operator()(const ReplaceStateTuple<S, P>& t) const {
124 static constexpr auto prime0 = 7853;
125 static constexpr auto prime1 = 7867;
126 return t.prefix_id + t.fst_id * prime0 + t.fst_state * prime1;
130 // Container for stack prefix.
131 template <class Label, class StateId>
132 class ReplaceStackPrefix {
135 PrefixTuple(Label fst_id = kNoLabel, StateId nextstate = kNoStateId)
136 : fst_id(fst_id), nextstate(nextstate) {}
142 ReplaceStackPrefix() {}
144 ReplaceStackPrefix(const ReplaceStackPrefix &other)
145 : prefix_(other.prefix_) {}
147 void Push(StateId fst_id, StateId nextstate) {
148 prefix_.push_back(PrefixTuple(fst_id, nextstate));
151 void Pop() { prefix_.pop_back(); }
153 const PrefixTuple &Top() const { return prefix_[prefix_.size() - 1]; }
155 size_t Depth() const { return prefix_.size(); }
158 std::vector<PrefixTuple> prefix_;
161 // Equality stack prefix classes.
162 template <class Label, class StateId>
163 inline bool operator==(const ReplaceStackPrefix<Label, StateId> &x,
164 const ReplaceStackPrefix<Label, StateId> &y) {
165 if (x.prefix_.size() != y.prefix_.size()) return false;
166 for (size_t i = 0; i < x.prefix_.size(); ++i) {
167 if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
168 x.prefix_[i].nextstate != y.prefix_[i].nextstate) {
175 // Hash function for stack prefix to prefix id.
176 template <class Label, class StateId>
177 class ReplaceStackPrefixHash {
179 size_t operator()(const ReplaceStackPrefix<Label, StateId> &prefix) const {
181 for (const auto &pair : prefix.prefix_) {
182 static constexpr auto prime = 7863;
183 sum += pair.fst_id + pair.nextstate * prime;
189 // Replace state tables.
191 // A two-level state table for replace. Warning: calls CountStates to compute
192 // the number of states of each component FST.
//
// State tuples are stored in a VectorHashStateTable configured with the root
// selector and the two fingerprint functors defined earlier in this file;
// stack prefixes are stored in a separate CompactHashBiTable.
193 template <class Arc, class P = ssize_t>
194 class VectorHashReplaceStateTable {
196 using Label = typename Arc::Label;
197 using StateId = typename Arc::StateId;
201 using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
203 VectorHashStateTable<ReplaceStateTuple<StateId, PrefixId>,
204 ReplaceRootSelector<StateId, PrefixId>,
205 ReplaceFstStateFingerprint<StateId, PrefixId>,
206 ReplaceFingerprint<StateId, PrefixId>>;
207 using StackPrefix = ReplaceStackPrefix<Label, StateId>;
208 using StackPrefixTable =
209 CompactHashBiTable<PrefixId, StackPrefix,
210 ReplaceStackPrefixHash<Label, StateId>>;
// Required constructor. Builds size_array_ as a running total: it starts with
// a single 0 and gains one entry per (label, FST) pair. The root FST's state
// count is stored in root_size_ and its entry repeats the previous total.
// NOTE(review): the branch keyword between the two push_back calls is not
// visible here; presumably the second one is the non-root case, which adds
// the FST's state count to the previous total.
212 VectorHashReplaceStateTable(
213 const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
216 size_array_.push_back(0);
217 for (const auto &fst_pair : fst_list) {
218 if (fst_pair.first == root) {
219 root_size_ = CountStates(*(fst_pair.second));
220 size_array_.push_back(size_array_.back());
222 size_array_.push_back(size_array_.back() +
223 CountStates(*(fst_pair.second)));
// The fingerprint functor is handed the address of this object's own
// size_array_ member.
227 new StateTable(new ReplaceRootSelector<StateId, PrefixId>,
228 new ReplaceFstStateFingerprint<StateId, PrefixId>,
229 new ReplaceFingerprint<StateId, PrefixId>(&size_array_),
230 root_size_, root_size_ + size_array_.back()));
// Copy constructor. Copies the sizes and the prefix table, but builds a brand
// new StateTable instead of copying the source's, so the state tuples of
// `table` are not carried over. The new fingerprint points at this object's
// size_array_, not at the source's.
233 VectorHashReplaceStateTable(
234 const VectorHashReplaceStateTable<Arc, PrefixId> &table)
235 : root_size_(table.root_size_),
236 size_array_(table.size_array_),
237 prefix_table_(table.prefix_table_) {
239 new StateTable(new ReplaceRootSelector<StateId, PrefixId>,
240 new ReplaceFstStateFingerprint<StateId, PrefixId>,
241 new ReplaceFingerprint<StateId, PrefixId>(&size_array_),
242 root_size_, root_size_ + size_array_.back()));
// Forwards to the underlying state table.
245 StateId FindState(const StateTuple &tuple) {
246 return state_table_->FindState(tuple);
// Forwards to the underlying state table.
249 const StateTuple &Tuple(StateId id) const { return state_table_->Tuple(id); }
// Forwards to the prefix table's FindId.
251 PrefixId FindPrefixId(const StackPrefix &prefix) {
252 return prefix_table_.FindId(prefix);
// Forwards to the prefix table's FindEntry.
255 const StackPrefix& GetStackPrefix(PrefixId id) const {
256 return prefix_table_.FindEntry(id);
// size_array_ is declared before state_table_; the fingerprint inside the
// state table keeps a raw pointer to it.
261 std::vector<uint64> size_array_;
262 std::unique_ptr<StateTable> state_table_;
263 StackPrefixTable prefix_table_;
266 // Default replace state table.
//
// Derives from a CompactHashStateTable keyed on replace state tuples and
// hashed with ReplaceHash, so FindState and Tuple come from the base class.
// Stack prefixes live in a separate CompactHashBiTable member.
267 template <class Arc, class P /* = size_t */>
268 class DefaultReplaceStateTable
269 : public CompactHashStateTable<ReplaceStateTuple<typename Arc::StateId, P>,
270 ReplaceHash<typename Arc::StateId, P>> {
272 using Label = typename Arc::Label;
273 using StateId = typename Arc::StateId;
276 using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
278 CompactHashStateTable<StateTuple, ReplaceHash<StateId, PrefixId>>;
279 using StackPrefix = ReplaceStackPrefix<Label, StateId>;
280 using StackPrefixTable =
281 CompactHashBiTable<PrefixId, StackPrefix,
282 ReplaceStackPrefixHash<Label, StateId>>;
// Re-export the base-class lookups under this class's name.
284 using StateTable::FindState;
285 using StateTable::Tuple;
// Required constructor. Both arguments (the FST list and the root label) are
// unnamed and ignored.
287 DefaultReplaceStateTable(
288 const std::vector<std::pair<Label, const Fst<Arc> *>> &, Label) {}
// Copy constructor. The base state table is default-constructed rather than
// copied; only the prefix table is copied from `table`.
290 DefaultReplaceStateTable(const DefaultReplaceStateTable<Arc, PrefixId> &table)
291 : StateTable(), prefix_table_(table.prefix_table_) {}
// Forwards to the prefix table's FindId.
293 PrefixId FindPrefixId(const StackPrefix &prefix) {
294 return prefix_table_.FindId(prefix);
// Forwards to the prefix table's FindEntry.
297 const StackPrefix &GetStackPrefix(PrefixId id) const {
298 return prefix_table_.FindEntry(id);
302 StackPrefixTable prefix_table_;
305 // By default ReplaceFst will copy the input label of the replace arc.
306 // The call_label_type and return_label_type options specify how to manage
307 // the labels of the call arc and the return arc of the replace FST
//
// The struct extends CacheImplOptions<CacheStore> with the replace-specific
// settings below; each field carries an in-class default.
308 template <class Arc, class StateTable = DefaultReplaceStateTable<Arc>,
309 class CacheStore = DefaultCacheStore<Arc>>
310 struct ReplaceFstOptions : CacheImplOptions<CacheStore> {
311 using Label = typename Arc::Label;
313 // Index of root rule for expansion.
315 // How to label call arc.
316 ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT;
317 // How to label return arc.
318 ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER;
319 // Specifies output label to put on call arc; if kNoLabel, use existing label
320 // on call arc. Otherwise, use this field as the output label.
321 Label call_output_label = kNoLabel;
322 // Specifies label to put on return arc.
323 Label return_label = 0;
324 // Take ownership of input FSTs?
325 bool take_ownership = false;
326 // Pointer to optional pre-constructed state table.
327 StateTable *state_table = nullptr;
// Starts from existing cache-impl options; all replace fields keep their
// defaults except the root label.
329 explicit ReplaceFstOptions(const CacheImplOptions<CacheStore> &opts,
330 Label root = kNoLabel)
331 : CacheImplOptions<CacheStore>(opts), root(root) {}
// Same as above but starting from plain CacheOptions.
333 explicit ReplaceFstOptions(const CacheOptions &opts, Label root = kNoLabel)
334 : CacheImplOptions<CacheStore>(opts), root(root) {}
336 // FIXME(kbg): There are too many constructors here. Come up with a consistent
337 // position for call_output_label (probably the very end) so that it is
338 // possible to express all the remaining constructors with a single
339 // default-argument constructor. Also move clients off of the "backwards
340 // compatibility" constructor, for good.
// Sets only the root label.
342 explicit ReplaceFstOptions(Label root) : root(root) {}
// Sets the root, both label types and the return label; call_output_label
// keeps its default.
344 explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
345 ReplaceLabelType return_label_type,
348 call_label_type(call_label_type),
349 return_label_type(return_label_type),
350 return_label(return_label) {}
// As above, additionally setting call_output_label.
352 explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
353 ReplaceLabelType return_label_type,
354 Label call_output_label, Label return_label)
356 call_label_type(call_label_type),
357 return_label_type(return_label_type),
358 call_output_label(call_output_label),
359 return_label(return_label) {}
// Converts from ReplaceUtilOptions by delegating to the four-argument
// constructor with the matching fields of `opts`.
361 explicit ReplaceFstOptions(const ReplaceUtilOptions &opts)
362 : ReplaceFstOptions(opts.root, opts.call_label_type,
363 opts.return_label_type, opts.return_label) {}
// Default: no root label selected.
365 ReplaceFstOptions() : root(kNoLabel) {}
367 // For backwards compatibility.
// When epsilon_replace_arc is true the call arc gets type NEITHER and output
// label 0; otherwise it gets type INPUT and output label kNoLabel.
368 ReplaceFstOptions(int64 root, bool epsilon_replace_arc)
370 call_label_type(epsilon_replace_arc ? REPLACE_LABEL_NEITHER
371 : REPLACE_LABEL_INPUT),
372 call_output_label(epsilon_replace_arc ? 0 : kNoLabel) {}
376 // Forward declaration.
377 template <class Arc, class StateTable, class CacheStore>
378 class ReplaceFstMatcher;
381 using FstList = std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>;
383 // Returns true if label type on arc results in epsilon input label.
384 inline bool EpsilonOnInput(ReplaceLabelType label_type) {
385 return label_type == REPLACE_LABEL_NEITHER ||
386 label_type == REPLACE_LABEL_OUTPUT;
389 // Returns true if label type on arc results in epsilon input label.
390 inline bool EpsilonOnOutput(ReplaceLabelType label_type) {
391 return label_type == REPLACE_LABEL_NEITHER ||
392 label_type == REPLACE_LABEL_INPUT;
395 // Returns true if for either the call or return arc ilabel != olabel.
396 template <class Label>
397 bool ReplaceTransducer(ReplaceLabelType call_label_type,
398 ReplaceLabelType return_label_type,
399 Label call_output_label) {
400 return call_label_type == REPLACE_LABEL_INPUT ||
401 call_label_type == REPLACE_LABEL_OUTPUT ||
402 (call_label_type == REPLACE_LABEL_BOTH &&
403 call_output_label != kNoLabel) ||
404 return_label_type == REPLACE_LABEL_INPUT ||
405 return_label_type == REPLACE_LABEL_OUTPUT;
// Computes the properties of the replace FST from its components.
//
// Scans every (label, FST) pair once, gathering each component's copyable
// properties plus a few aggregate flags, and passes them to
// ReplaceProperties() together with the epsilon/transducer flags derived from
// the label types. *sorted_and_non_empty is set to true when every component
// has a start state and the result is ilabel- or olabel-sorted.
409 uint64 ReplaceFstProperties(typename Arc::Label root_label,
410 const FstList<Arc> &fst_list,
411 ReplaceLabelType call_label_type,
412 ReplaceLabelType return_label_type,
413 typename Arc::Label call_output_label,
414 bool *sorted_and_non_empty) {
415 using Label = typename Arc::Label;
416 std::vector<uint64> inprops;
417 bool all_ilabel_sorted = true;
418 bool all_olabel_sorted = true;
419 bool all_non_empty = true;
420 // All nonterminals are negative?
421 bool all_negative = true;
422 // All nonterminals are positive and form a dense range containing 1?
423 bool dense_range = true;
// Position of the root within fst_list; stays 0 if root_label never matches.
424 Label root_fst_idx = 0;
425 for (Label i = 0; i < fst_list.size(); ++i) {
426 const auto label = fst_list[i].first;
427 if (label >= 0) all_negative = false;
428 if (label > fst_list.size() || label <= 0) dense_range = false;
429 if (label == root_label) root_fst_idx = i;
430 const auto *fst = fst_list[i].second;
// A component without a start state counts as empty.
431 if (fst->Start() == kNoStateId) all_non_empty = false;
// The `false` argument asks only for already-known properties.
// NOTE(review): that is the usual meaning of Properties(mask, false) in this
// library; it is not established by this file.
432 if (!fst->Properties(kILabelSorted, false)) all_ilabel_sorted = false;
433 if (!fst->Properties(kOLabelSorted, false)) all_olabel_sorted = false;
434 inprops.push_back(fst->Properties(kCopyProperties, false));
436 const auto props = ReplaceProperties(
437 inprops, root_fst_idx, EpsilonOnInput(call_label_type),
438 EpsilonOnInput(return_label_type), EpsilonOnOutput(call_label_type),
439 EpsilonOnOutput(return_label_type),
440 ReplaceTransducer(call_label_type, return_label_type, call_output_label),
441 all_non_empty, all_ilabel_sorted, all_olabel_sorted,
442 all_negative || dense_range);
// Sorted on at least one side.
443 const bool sorted = props & (kILabelSorted | kOLabelSorted);
444 *sorted_and_non_empty = all_non_empty && sorted;
450 // The replace implementation class supports a dynamic expansion of a recursive
451 // transition network represented as label/FST pairs with dynamically replaceable
453 template <class Arc, class StateTable, class CacheStore>
455 : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
457 using Label = typename Arc::Label;
458 using StateId = typename Arc::StateId;
459 using Weight = typename Arc::Weight;
461 using State = typename CacheStore::State;
462 using CacheImpl = CacheBaseImpl<State, CacheStore>;
463 using PrefixId = typename StateTable::PrefixId;
464 using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
465 using StackPrefix = ReplaceStackPrefix<Label, StateId>;
466 using NonTerminalHash = std::unordered_map<Label, Label>;
468 using FstImpl<Arc>::SetType;
469 using FstImpl<Arc>::SetProperties;
470 using FstImpl<Arc>::WriteHeader;
471 using FstImpl<Arc>::SetInputSymbols;
472 using FstImpl<Arc>::SetOutputSymbols;
473 using FstImpl<Arc>::InputSymbols;
474 using FstImpl<Arc>::OutputSymbols;
476 using CacheImpl::PushArc;
477 using CacheImpl::HasArcs;
478 using CacheImpl::HasFinal;
479 using CacheImpl::HasStart;
480 using CacheImpl::SetArcs;
481 using CacheImpl::SetFinal;
482 using CacheImpl::SetStart;
484 friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
// Builds the implementation from the (label, FST) pairs and the options.
//
// Slot 0 of fst_array_ is a null placeholder, so component FSTs are indexed
// from 1 and nonterminal_hash_ maps each label to such an index. Symbol
// tables are taken from the first FST in the list and every component is
// checked against them; a mismatch sets kError.
486 ReplaceFstImpl(const FstList<Arc> &fst_list,
487 const ReplaceFstOptions<Arc, StateTable, CacheStore> &opts)
489 call_label_type_(opts.call_label_type),
490 return_label_type_(opts.return_label_type),
491 call_output_label_(opts.call_output_label),
492 return_label_(opts.return_label),
// Uses the caller-supplied state table when present, otherwise allocates one.
// Either way the pointer ends up held by the state_table_ member.
493 state_table_(opts.state_table ? opts.state_table
494 : new StateTable(fst_list, opts.root)) {
496 // If the label is epsilon, then all replace label options are equivalent,
497 // so we set the label types to NEITHER for simplicity.
498 if (call_output_label_ == 0) call_label_type_ = REPLACE_LABEL_NEITHER;
499 if (return_label_ == 0) return_label_type_ = REPLACE_LABEL_NEITHER;
500 if (!fst_list.empty()) {
501 SetInputSymbols(fst_list[0].second->InputSymbols());
502 SetOutputSymbols(fst_list[0].second->OutputSymbols());
// Index 0 is reserved so that 0 can mean "no such nonterminal" below.
504 fst_array_.push_back(nullptr);
505 for (Label i = 0; i < fst_list.size(); ++i) {
506 const auto label = fst_list[i].first;
507 const auto *fst = fst_list[i].second;
508 nonterminal_hash_[label] = fst_array_.size();
509 nonterminal_set_.insert(label);
// Stores the caller's pointer when take_ownership is set, otherwise a copy.
510 fst_array_.emplace_back(opts.take_ownership ? fst : fst->Copy());
512 if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
513 FSTERROR() << "ReplaceFstImpl: Input symbols of FST " << i
514 << " do not match input symbols of base FST (0th FST)";
515 SetProperties(kError, kError);
517 if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
518 FSTERROR() << "ReplaceFstImpl: Output symbols of FST " << i
519 << " do not match output symbols of base FST (0th FST)";
520 SetProperties(kError, kError);
// operator[] value-initializes a missing entry, so an unknown root label
// yields 0 here (and leaves a 0 entry behind in nonterminal_hash_).
524 const auto nonterminal = nonterminal_hash_[opts.root];
525 if ((nonterminal == 0) && (fst_array_.size() > 1)) {
526 FSTERROR() << "ReplaceFstImpl: No FST corresponding to root label "
527 << opts.root << " in the input tuple vector";
528 SetProperties(kError, kError);
// Falls back to the first component when the root label was not found.
530 root_ = (nonterminal > 0) ? nonterminal : 1;
531 bool all_non_empty_and_sorted = false;
532 SetProperties(ReplaceFstProperties(opts.root, fst_list, call_label_type_,
533 return_label_type_, call_output_label_,
534 &all_non_empty_and_sorted));
535 // Enables optional caching as long as sorted and all non-empty.
536 always_cache_ = !all_non_empty_and_sorted;
537 VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
538 << (always_cache_ ? "true" : "false");
// Copy constructor. Copies the label settings and nonterminal tables, builds
// a new state table through StateTable's copy constructor, and re-creates
// every component FST with Copy(true).
541 ReplaceFstImpl(const ReplaceFstImpl &impl)
543 call_label_type_(impl.call_label_type_),
544 return_label_type_(impl.return_label_type_),
545 call_output_label_(impl.call_output_label_),
546 return_label_(impl.return_label_),
547 always_cache_(impl.always_cache_),
548 state_table_(new StateTable(*(impl.state_table_))),
549 nonterminal_set_(impl.nonterminal_set_),
550 nonterminal_hash_(impl.nonterminal_hash_),
553 SetProperties(impl.Properties(), kCopyProperties);
554 SetInputSymbols(impl.InputSymbols());
555 SetOutputSymbols(impl.OutputSymbols());
556 fst_array_.reserve(impl.fst_array_.size());
// Slot 0 is the null placeholder, so copying starts at index 1.
557 fst_array_.emplace_back(nullptr);
558 for (Label i = 1; i < impl.fst_array_.size(); ++i) {
559 fst_array_.emplace_back(impl.fst_array_[i]->Copy(true));
563 // Computes the dependency graph of the replace class and returns
564 // true if the dependencies are cyclic. Cyclic dependencies will result
565 // in an un-expandable FST.
//
// The check is delegated to a temporary ReplaceUtil built over this object's
// component FSTs and nonterminal table, with root_ as the root.
566 bool CyclicDependencies() const {
567 const ReplaceUtilOptions opts(root_);
568 ReplaceUtil<Arc> replace_util(fst_array_, nonterminal_hash_, opts);
569 return replace_util.CyclicDependencies();
574 if (fst_array_.size() == 1) {
575 SetStart(kNoStateId);
578 const auto fst_start = fst_array_[root_]->Start();
579 if (fst_start == kNoStateId) return kNoStateId;
580 const auto prefix = GetPrefixId(StackPrefix());
582 state_table_->FindState(StateTuple(prefix, root_, fst_start));
587 return CacheImpl::Start();
// Returns the final weight of state s, using the cached value if there is
// one. The weight starts as Weight::Zero() and is only replaced by the
// component FST's final weight when the tuple's prefix ID is 0.
591 Weight Final(StateId s) {
592 if (HasFinal(s)) return CacheImpl::Final(s);
593 const auto &tuple = state_table_->Tuple(s);
594 auto weight = Weight::Zero();
595 if (tuple.prefix_id == 0) {
596 const auto fst_state = tuple.fst_state;
597 weight = fst_array_[tuple.fst_id]->Final(fst_state);
// The weight is cached only when always caching or when s already has arcs.
599 if (always_cache_ || HasArcs(s)) SetFinal(s, weight);
// Returns the number of arcs leaving state s. Without caching, this is the
// arc count of the underlying component state plus one when a final (return)
// arc is required.
603 size_t NumArcs(StateId s) {
605 return CacheImpl::NumArcs(s);
606 } else if (always_cache_) { // If always caching, expands and caches state.
608 return CacheImpl::NumArcs(s);
609 } else { // Otherwise computes the number of arcs without expanding.
610 const auto tuple = state_table_->Tuple(s);
611 if (tuple.fst_state == kNoStateId) return 0;
612 auto num_arcs = fst_array_[tuple.fst_id]->NumArcs(tuple.fst_state);
// A null arc pointer asks only whether a final arc exists.
613 if (ComputeFinalArc(tuple, nullptr)) ++num_arcs;
618 // Returns whether a given label is a non-terminal.
//
// Labels outside the [smallest, largest] range of nonterminal_set_ are
// rejected without a hash lookup; the rest are looked up in
// nonterminal_hash_.
// NOTE(review): dereferences begin()/rbegin() of nonterminal_set_, so this
// presumably expects at least one nonterminal to be registered.
619 bool IsNonTerminal(Label label) const {
620 if (label < *nonterminal_set_.begin() ||
621 label > *nonterminal_set_.rbegin()) {
624 return nonterminal_hash_.count(label);
626 // TODO(allauzen): be smarter and take advantage of all_dense or
627 // all_negative. Also use this in ComputeArc. This would require changes to
628 // Replace so that recursing into an empty FST leads to a non co-accessible
629 // state instead of deleting the arc as done currently. The current use is
630 // correct, since labels are sorted if all_non_empty is true.
// Returns the number of input-epsilon arcs leaving state s. The uncached path
// is only taken when the FST is ilabel-sorted and caching is optional.
633 size_t NumInputEpsilons(StateId s) {
635 return CacheImpl::NumInputEpsilons(s);
636 } else if (always_cache_ || !Properties(kILabelSorted)) {
637 // If always caching or if the number of input epsilons is too expensive
638 // to compute without caching (i.e., not ilabel-sorted), then expands and
641 return CacheImpl::NumInputEpsilons(s);
643 // Otherwise, computes the number of input epsilons without caching.
644 const auto tuple = state_table_->Tuple(s);
645 if (tuple.fst_state == kNoStateId) return 0;
647 if (!EpsilonOnInput(call_label_type_)) {
648 // If EpsilonOnInput(c) is false, all input epsilon arcs
649 // are also input epsilon arcs in the underlying machine.
650 num = fst_array_[tuple.fst_id]->NumInputEpsilons(tuple.fst_state);
652 // Otherwise, one needs to consider that all non-terminal arcs
653 // in the underlying machine also become input epsilon arcs.
654 ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
// Walks the leading run of arcs that are input-epsilon or non-terminal and
// stops at the first arc that is neither.
655 for (; !aiter.Done() && ((aiter.Value().ilabel == 0) ||
656 IsNonTerminal(aiter.Value().olabel));
// The return arc is considered as well when it has epsilon on the input side.
661 if (EpsilonOnInput(return_label_type_) &&
662 ComputeFinalArc(tuple, nullptr)) {
// Returns the number of output-epsilon arcs leaving state s. The uncached
// path is only taken when the FST is olabel-sorted and caching is optional.
669 size_t NumOutputEpsilons(StateId s) {
671 return CacheImpl::NumOutputEpsilons(s);
672 } else if (always_cache_ || !Properties(kOLabelSorted)) {
673 // If always caching or if the number of output epsilons is too expensive
674 // to compute without caching (i.e., not olabel-sorted), then expands and
677 return CacheImpl::NumOutputEpsilons(s);
679 // Otherwise, computes the number of output epsilons without caching.
680 const auto tuple = state_table_->Tuple(s);
681 if (tuple.fst_state == kNoStateId) return 0;
683 if (!EpsilonOnOutput(call_label_type_)) {
684 // If EpsilonOnOutput(c) is false, all output epsilon arcs are also
685 // output epsilon arcs in the underlying machine.
686 num = fst_array_[tuple.fst_id]->NumOutputEpsilons(tuple.fst_state);
688 // Otherwise, one needs to consider that all non-terminal arcs in the
689 // underlying machine also become output epsilon arcs.
690 ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
// Walks the leading run of arcs that are output-epsilon or non-terminal and
// stops at the first arc that is neither.
691 for (; !aiter.Done() && ((aiter.Value().olabel == 0) ||
692 IsNonTerminal(aiter.Value().olabel));
// The return arc is considered as well when it has epsilon on the output
// side.
697 if (EpsilonOnOutput(return_label_type_) &&
698 ComputeFinalArc(tuple, nullptr)) {
705 uint64 Properties() const override { return Properties(kFstProperties); }
707 // Sets error if found, and returns other FST impl properties.
//
// Checks each component FST (indices 1 and up) for kError and, if one has it,
// marks this FST with kError before answering from the base class.
708 uint64 Properties(uint64 mask) const override {
710 for (Label i = 1; i < fst_array_.size(); ++i) {
711 if (fst_array_[i]->Properties(kError, false)) {
712 SetProperties(kError, kError);
716 return FstImpl<Arc>::Properties(mask);
719 // Returns the base arc iterator, and if arcs have not been computed yet,
720 // extends and recurses for new arcs.
//
// State s is expanded first when it has no cached arcs, then the cache
// implementation fills in `data`.
721 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
722 if (!HasArcs(s)) Expand(s);
723 CacheImpl::InitArcIterator(s, data);
724 // TODO(allauzen): Set behaviour of generic iterator.
725 // Warning: ArcIterator<ReplaceFst<A>>::InitCache() relies on current
729 // Extends current state (walk arcs one level deep).
//
// Looks up the tuple for s, pushes the final (return) arc when one is needed,
// then converts each arc of the underlying component state with ComputeArc
// and pushes those for which it returns true.
730 void Expand(StateId s) {
731 const auto tuple = state_table_->Tuple(s);
732 if (tuple.fst_state == kNoStateId) { // Local FST is empty.
736 ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
738 // Creates a final arc when needed.
739 if (ComputeFinalArc(tuple, &arc)) PushArc(s, arc);
740 // Expands all arcs leaving the state.
741 for (; !aiter.Done(); aiter.Next()) {
742 if (ComputeArc(tuple, aiter.Value(), &arc)) PushArc(s, arc);
// Variant of Expand for callers that already hold the state tuple and arc
// iterator data of the underlying component state. It iterates over `data`
// instead of opening a new iterator and adds arcs with AddArc rather than
// PushArc.
747 void Expand(StateId s, const StateTuple &tuple,
748 const ArcIteratorData<Arc> &data) {
749 if (tuple.fst_state == kNoStateId) { // Local FST is empty.
753 ArcIterator<Fst<Arc>> aiter(data);
755 // Creates a final arc when needed.
756 if (ComputeFinalArc(tuple, &arc)) AddArc(s, arc);
757 // Expands all arcs leaving the state.
758 for (; !aiter.Done(); aiter.Next()) {
759 if (ComputeArc(tuple, aiter.Value(), &arc)) AddArc(s, arc);
764 // If arcp is null, only returns true if a final arc is required, but does
765 // not actually compute it.
//
// `flags` selects which fields of *arcp are filled in: the next state is
// computed only under kArcNextStateValue and the weight only under
// kArcWeightValue.
766 bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp,
767 uint32 flags = kArcValueFlags) {
768 const auto fst_state = tuple.fst_state;
769 if (fst_state == kNoStateId) return false;
770 // If state is final, pops the stack.
771 if (fst_array_[tuple.fst_id]->Final(fst_state) != Weight::Zero() &&
// The return arc carries return_label_ on each side unless the return label
// type makes that side epsilon.
774 arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_;
776 (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_;
777 if (flags & kArcNextStateValue) {
// The destination is the state recorded in the top stack frame, reached under
// the prefix that remains after popping that frame.
778 const auto &stack = state_table_->GetStackPrefix(tuple.prefix_id);
779 const auto prefix_id = PopPrefix(stack);
780 const auto &top = stack.Top();
781 arcp->nextstate = state_table_->FindState(
782 StateTuple(prefix_id, top.fst_id, top.nextstate));
784 if (flags & kArcWeightValue) {
// The return arc's weight is the final weight of the component state.
785 arcp->weight = fst_array_[tuple.fst_id]->Final(fst_state);
794 // Computes an arc in the FST corresponding to one in the underlying machine.
795 // Returns false if the underlying arc corresponds to no arc in the resulting
//
// Three cases are visible below: the output label is epsilon or lies outside
// the nonterminal range (plain local arc), the output label is a registered
// nonterminal (call arc into that component), or it lies inside the range but
// is not registered (plain local arc again). The destination state is only
// looked up when kArcNextStateValue is set in `flags`.
797 bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp,
798 uint32 flags = kArcValueFlags) {
// Shortcut for callers that ask for nothing beyond the input label and weight
// while call arcs keep their input label.
// NOTE(review): the body of this branch is not visible in this listing.
799 if (!EpsilonOnInput(call_label_type_) &&
800 (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
804 if (arc.olabel == 0 || arc.olabel < *nonterminal_set_.begin() ||
805 arc.olabel > *nonterminal_set_.rbegin()) { // Expands local FST.
// Same prefix and component FST; only the component state advances.
806 const auto nextstate =
807 flags & kArcNextStateValue
808 ? state_table_->FindState(
809 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
811 *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
813 // Checks for non-terminal.
814 const auto it = nonterminal_hash_.find(arc.olabel);
815 if (it != nonterminal_hash_.end()) { // Recurses into non-terminal.
816 const auto nonterminal = it->second;
// Pushes a frame naming the current component and the arc's destination.
817 const auto nt_prefix =
818 PushPrefix(state_table_->GetStackPrefix(tuple.prefix_id),
819 tuple.fst_id, arc.nextstate);
820 // If the start state is valid, replace; otherwise, the arc is implicitly
822 const auto nt_start = fst_array_[nonterminal]->Start();
823 if (nt_start != kNoStateId) {
824 const auto nt_nextstate = flags & kArcNextStateValue
825 ? state_table_->FindState(StateTuple(
826 nt_prefix, nonterminal, nt_start))
// Call-arc labels: each side is epsilon when the call label type says so;
// otherwise the input side keeps the arc's input label and the output side
// uses call_output_label_ unless that is kNoLabel.
829 (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel;
831 (EpsilonOnOutput(call_label_type_))
833 : ((call_output_label_ == kNoLabel) ? arc.olabel
834 : call_output_label_);
835 *arcp = Arc(ilabel, olabel, arc.weight, nt_nextstate);
// Label is within the nonterminal range but not registered: treated like a
// local arc.
840 const auto nextstate =
841 flags & kArcNextStateValue
842 ? state_table_->FindState(
843 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
845 *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
851 // Returns the arc iterator flags supported by this FST.
//
// Starts from kArcValueFlags and adds kArcNoCache when caching is optional
// (always_cache_ is false).
852 uint32 ArcIteratorFlags() const {
853 uint32 flags = kArcValueFlags;
854 if (!always_cache_) flags |= kArcNoCache;
858 StateTable *GetStateTable() const { return state_table_.get(); }
860 const Fst<Arc> *GetFst(Label fst_id) const {
861 return fst_array_[fst_id].get();
// Looks up `nonterminal` in nonterminal_hash_, which maps nonterminal labels
// to indices into fst_array_. An unknown label is reported through
// FSTERROR().
864 Label GetFstId(Label nonterminal) const {
865 const auto it = nonterminal_hash_.find(nonterminal);
866 if (it == nonterminal_hash_.end()) {
867 FSTERROR() << "ReplaceFstImpl::GetFstId: Nonterminal not found: "
873 // Returns true if label type on call arc results in epsilon input label.
874 bool EpsilonOnCallInput() { return EpsilonOnInput(call_label_type_); }
877 // The unique index into stack prefix table.
878 PrefixId GetPrefixId(const StackPrefix &prefix) {
879 return state_table_->FindPrefixId(prefix);
882 // The prefix ID after a stack pop.
//
// `prefix` is taken by value, so the caller's stack prefix is not modified.
883 PrefixId PopPrefix(StackPrefix prefix) {
885 return GetPrefixId(prefix);
888 // The prefix ID after a stack push.
//
// `prefix` is taken by value, so the frame (fst_id, nextstate) is pushed onto
// a private copy and the caller's stack prefix is not modified.
889 PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
890 prefix.Push(fst_id, nextstate);
891 return GetPrefixId(prefix);
895 ReplaceLabelType call_label_type_; // How to label call arc.
896 ReplaceLabelType return_label_type_; // How to label return arc.
897 int64 call_output_label_; // Specifies output label to put on call arc
898 int64 return_label_; // Specifies label to put on return arc.
899 bool always_cache_; // Disable optional caching of arc iterator?
902 std::unique_ptr<StateTable> state_table_;
904 // Replace components.
905 std::set<Label> nonterminal_set_;
906 NonTerminalHash nonterminal_hash_;
907 std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
911 } // namespace internal
914 // ReplaceFst supports dynamic replacement of arcs in one FST with another FST.
915 // This replacement is recursive. ReplaceFst can be used to support a variety of
916 // delayed constructions such as recursive
917 // transition networks, union, or closure. It is constructed with an array of
918 // FST(s). One FST represents the root (or topology) machine. The root FST
919 // refers to other FSTs by recursively replacing arcs labeled as non-terminals
920 // with the matching non-terminal FST. Currently the ReplaceFst uses the output
921 // symbols of the arcs to determine whether the arc is a non-terminal arc or
922 // not. A non-terminal can be any label that is not a non-zero terminal label in
923 // the output alphabet.
925 // Note that the constructor uses a vector of pairs. These correspond to the
926 // tuple of non-terminal Label and corresponding FST. For example to implement
927 // the closure operation we need 2 FSTs. The first root FST is a single
928 // self-loop arc on the start state.
930 // The ReplaceFst class supports an optionally caching arc iterator.
932 // The ReplaceFst needs to be built such that it is known to be ilabel- or
933 // olabel-sorted (see usage below).
935 // Observe that Matcher<Fst<A>> will use the optionally caching arc iterator
936 // when available (the FST is ilabel-sorted and matching on the input, or the
937 // FST is olabel-sorted and matching on the output). In order to obtain the
938 // most efficient behaviour, it is recommended to set call_label_type to
939 // REPLACE_LABEL_INPUT or REPLACE_LABEL_BOTH and return_label_type to
940 // REPLACE_LABEL_OUTPUT or REPLACE_LABEL_NEITHER. This means that the call arc
941 // does not have epsilon on the input side and the return arc has epsilon on the
942 // input side) and matching on the input side.
944 // This class attaches interface to implementation and handles reference
945 // counting, delegating most methods to ImplToFst.
// ReplaceFst is parameterized on the arc type A, the replace state table T and
// the cache store CacheStore. It derives from ImplToFst instantiated on
// internal::ReplaceFstImpl<A, T, CacheStore>.
946 template <class A, class T /* = DefaultReplaceStateTable<A> */,
947 class CacheStore /* = DefaultCacheStore<A> */>
949 : public ImplToFst<internal::ReplaceFstImpl<A, T, CacheStore>> {
// Types forwarded from the arc type.
952 using Label = typename Arc::Label;
953 using StateId = typename Arc::StateId;
954 using Weight = typename Arc::Weight;
// Aliases for the template arguments, the cache state type and the
// implementation classes.
956 using StateTable = T;
957 using Store = CacheStore;
958 using State = typename CacheStore::State;
959 using Impl = internal::ReplaceFstImpl<Arc, StateTable, CacheStore>;
960 using CacheImpl = internal::CacheBaseImpl<State, CacheStore>;
962 using ImplToFst<Impl>::Properties;
// The iterator and matcher classes call GetImpl()/GetMutableImpl() on a
// ReplaceFst, hence the friend declarations.
964 friend class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
965 friend class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
966 friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
// Builds the implementation from fst_array (pairs of non-terminal label and
// FST) using options constructed from root alone.
968 ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
970 : ImplToFst<Impl>(std::make_shared<Impl>(
971 fst_array, ReplaceFstOptions<Arc, StateTable, CacheStore>(root))) {}
// Builds the implementation from fst_array using the caller-supplied options.
973 ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
974 const ReplaceFstOptions<Arc, StateTable, CacheStore> &opts)
975 : ImplToFst<Impl>(std::make_shared<Impl>(fst_array, opts)) {}
977 // See Fst<>::Copy() for doc.
// Copy constructor; fst and safe are forwarded to the ImplToFst constructor.
978 ReplaceFst(const ReplaceFst<Arc, StateTable, CacheStore> &fst,
980 : ImplToFst<Impl>(fst, safe) {}
982 // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
// The returned object is allocated with new; the caller receives a raw
// pointer.
983 ReplaceFst<Arc, StateTable, CacheStore> *Copy(
984 bool safe = false) const override {
985 return new ReplaceFst<Arc, StateTable, CacheStore>(*this, safe);
// Declared here only; the definition is given outside the class body.
988 inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
// Delegates arc iterator initialization for state s to the implementation.
990 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
991 GetMutableImpl()->InitArcIterator(s, data);
// Returns a ReplaceFstMatcher when the implementation's arc iterator flags
// include kArcNoCache and the FST has the sorted-label property for the side
// being matched (kILabelSorted for MATCH_INPUT, kOLabelSorted for
// MATCH_OUTPUT). Properties() is called with false as its second argument.
// NOTE(review): presumably false means the property is looked up but not
// computed - confirm against Fst::Properties.
994 MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
995 if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
996 ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
997 (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
998 return new ReplaceFstMatcher<Arc, StateTable, CacheStore>
// NOTE(review): this message is presumably logged on the fallback path, when
// the condition above does not hold - confirm.
1001 VLOG(2) << "Not using replace matcher";
// Forwards to the implementation's CyclicDependencies().
1006 bool CyclicDependencies() const { return GetImpl()->CyclicDependencies(); }
// Returns a reference to the state table held (by pointer) by the
// implementation.
1008 const StateTable &GetStateTable() const {
1009 return *GetImpl()->GetStateTable();
// Maps the non-terminal label to an FST id via GetFstId() and returns a
// reference to the FST stored under that id.
1012 const Fst<Arc> &GetFst(Label nonterminal) const {
1013 return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal));
// Brings the base-class implementation accessors into this class.
1017 using ImplToFst<Impl>::GetImpl;
1018 using ImplToFst<Impl>::GetMutableImpl;
// Copy assignment is disabled.
1020 ReplaceFst &operator=(const ReplaceFst &) = delete;
1023 // Specialization for ReplaceFst.
// Iteration behavior comes from the CacheStateIterator base class; the only
// member visible here is the constructor.
1024 template <class Arc, class StateTable, class CacheStore>
1025 class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>
1026 : public CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
// Forwards the FST and its mutable implementation to the base class.
1028 explicit StateIterator(const ReplaceFst<Arc, StateTable, CacheStore> &fst)
1029 : CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(
1030 fst, fst.GetMutableImpl()) {}
1033 // Specialization for ReplaceFst, implementing optional caching. It can be used
1036 // ReplaceFst<A> replace;
1037 // ArcIterator<ReplaceFst<A>> aiter(replace, s);
1038 // // Note: ArcIterator< Fst<A>> is always a caching arc iterator.
1039 // aiter.SetFlags(kArcNoCache, kArcNoCache);
1040 // // Uses the arc iterator, no arc will be cached, no state will be expanded.
1041 // // Arc flags can be used to decide which component of the arc need to be
1043 // aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1044 // // Wants the ilabel for this arc.
1045 // aiter.Value(); // Does not compute the destination state.
1047 // aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1048 // // Wants the ilabel and next state for this arc.
1049 // aiter.Value(); // Does compute the destination state and inserts it
1050 // // in the replace state table.
1051 // // No additional arcs have been cached at this point.
// Arc iterator specialization for ReplaceFst. It either reads arcs from the
// cache (cache_data_) or works from the arcs of the component FST named by the
// state's tuple (local_data_), computing ReplaceFst arcs on the fly.
1052 template <class Arc, class StateTable, class CacheStore>
1053 class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1055 using StateId = typename Arc::StateId;
1057 using StateTuple = typename StateTable::StateTuple;
// Sets up iteration over the arcs leaving state s of fst. All arc value flags
// are requested initially (flags_ starts as kArcValueFlags).
1059 ArcIterator(const ReplaceFst<Arc, StateTable, CacheStore> &fst, StateId s)
1064 flags_(kArcValueFlags),
// A null ref_count marks the corresponding arc data as not held; the
// ref-count decrements further down test for this.
1068 cache_data_.ref_count = nullptr;
1069 local_data_.ref_count = nullptr;
1070 // If FST does not support optional caching, forces caching.
1071 if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1072 !(fst_.GetImpl()->HasArcs(s_))) {
1073 fst_.GetMutableImpl()->Expand(s_);
1075 // If state is already cached, use cached arcs array.
1076 if (fst_.GetImpl()->HasArcs(s_)) {
// Explicitly qualified call to the CacheBaseImpl version of InitArcIterator.
1078 ->internal::template CacheBaseImpl<
1079 typename CacheStore::State,
1080 CacheStore>::InitArcIterator(s_, &cache_data_);
1081 num_arcs_ = cache_data_.narcs;
1082 arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
1083 data_flags_ = kArcValueFlags; // All the arc member values are valid.
1084 } else { // Otherwise delay decision until Value() is called.
1085 tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(s_);
1086 if (tuple_.fst_state == kNoStateId) {
1089 // The decision to cache or not to cache has been deferred until Value()
1091 // SetFlags() is called. However, the arc iterator is set up now to be
1092 // ready for non-caching in order to keep the Value() method simple and
// rfst is the component FST selected by the tuple's fst_id.
1094 const auto *rfst = fst_.GetImpl()->GetFst(tuple_.fst_id);
1095 rfst->InitArcIterator(tuple_.fst_state, &local_data_);
1096 // arcs_ is a pointer to the arcs in the underlying machine.
1097 arcs_ = local_data_.arcs;
1098 // Computes the final arc (but not its destination state) if a final arc
1100 bool has_final_arc = fst_.GetMutableImpl()->ComputeFinalArc(
1101 tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue);
1102 // Sets the arc value flags that hold for final_arc_.
1103 final_flags_ = kArcValueFlags & ~kArcNextStateValue;
1104 // Computes the number of arcs.
1105 num_arcs_ = local_data_.narcs;
1106 if (has_final_arc) ++num_arcs_;
1107 // Sets the offset between the underlying arc positions and the
1109 // in the arc iterator.
// offset_ is 1 when a final arc exists and 0 otherwise. Value() treats
// positions below offset_ as the final arc, so the final arc comes first and
// underlying arc i is found at iterator position i + offset_.
1110 offset_ = num_arcs_ - local_data_.narcs;
1111 // Defers the decision to cache or not until Value() or SetFlags() is
// Decrements the reference counts of the cached and local arc data, when set.
1119 if (cache_data_.ref_count) --(*cache_data_.ref_count);
1120 if (local_data_.ref_count) --(*local_data_.ref_count);
// Switches the iterator to the cached arcs of s_. It is const because it is
// used from const context; the fields it writes are declared mutable.
1123 void ExpandAndCache() const {
1124 // TODO(allauzen): revisit this.
1125 // fst_.GetImpl()->Expand(s_, tuple_, local_data_);
1126 // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(s_,
1129 fst_.InitArcIterator(s_, &cache_data_); // Expand and cache state.
1130 arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
1131 data_flags_ = kArcValueFlags; // All the arc member values are valid.
1132 offset_ = 0; // No offset.
// NOTE(review): the lines below appear to be the body of Init(), which
// SetFlags() calls to set the iterator up for non-caching - confirm.
1136 if (flags_ & kArcNoCache) { // If caching is disabled
1137 // arcs_ is a pointer to the arcs in the underlying machine.
1138 arcs_ = local_data_.arcs;
1139 // Sets the arcs value flags that hold for arcs_.
// The weight of an underlying arc is always usable as is; its input label is
// usable only when the call arc does not get an epsilon on the input side.
1140 data_flags_ = kArcWeightValue;
1141 if (!fst_.GetMutableImpl()->EpsilonOnCallInput()) {
1142 data_flags_ |= kArcILabelValue;
1144 // Sets the offset between the underlying arc positions and the positions
1145 // in the arc iterator.
1146 offset_ = num_arcs_ - local_data_.narcs;
// True once the position has moved past the last arc.
1152 bool Done() const { return pos_ >= num_arcs_; }
// Returns the arc at the current position. For a non-final arc, the arc is
// computed into arc_ with ComputeArc() when the valid value flags do not cover
// the requested ones. For the final arc, final_arc_ is recomputed only when
// its valid flags do not cover the requested ones.
1154 const Arc &Value() const {
1155 // If data_flags_ is 0, non-caching was not requested.
1157 // TODO(allauzen): Revisit this.
1158 if (flags_ & kArcNoCache) {
1159 // Should never happen.
1160 FSTERROR() << "ReplaceFst: Inconsistent arc iterator flags";
1164 if (pos_ - offset_ >= 0) { // The requested arc is not the final arc.
1165 const auto &arc = arcs_[pos_ - offset_];
1166 if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
1167 // If the value flags match the required value flags then returns the
1171 // Otherwise, compute the corresponding arc on-the-fly.
1172 fst_.GetMutableImpl()->ComputeArc(tuple_, arc, &arc_,
1173 flags_ & kArcValueFlags);
1176 } else { // The requested arc is the final arc.
1177 if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
1178 // If the arc value flags that hold for the final arc do not match the
1179 // requested value flags, then
1180 // final_arc_ needs to be updated.
1181 fst_.GetMutableImpl()->ComputeFinalArc(tuple_, &final_arc_,
1182 flags_ & kArcValueFlags);
1183 final_flags_ = flags_ & kArcValueFlags;
// Position handling; pos_ is stored as ssize_t although the interface uses
// size_t.
1189 void Next() { ++pos_; }
1191 size_t Position() const { return pos_; }
1193 void Reset() { pos_ = 0; }
1195 void Seek(size_t pos) { pos_ = pos; }
// Returns the current behavioral flags.
1197 uint32 Flags() const { return flags_; }
// Updates the behavioral flags. Only the flags that the implementation reports
// via ArcIteratorFlags() are added to flags_.
1199 void SetFlags(uint32 flags, uint32 mask) {
1200 // Updates the flags taking into account what flags are supported
1203 flags_ |= (flags & fst_.GetImpl()->ArcIteratorFlags());
1204 // If non-caching is not requested (and caching has not already been
1205 // performed), then flush data_flags_ to request caching during the next
1207 if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
1208 if (!fst_.GetImpl()->HasArcs(s_)) data_flags_ = 0;
1210 // If data_flags_ has been flushed but non-caching is requested before
1211 // calling Value(), then set up the iterator for non-caching.
1212 if ((flags & kArcNoCache) && (!data_flags_)) Init();
1216 const ReplaceFst<Arc, StateTable, CacheStore> &fst_; // Reference to the FST.
1217 StateId s_; // State in the FST.
1218 mutable StateTuple tuple_; // Tuple corresponding to s_.
1220 ssize_t pos_; // Current position.
1221 mutable ssize_t offset_; // Offset between position in iterator and in arcs_.
1222 ssize_t num_arcs_; // Number of arcs at s_.
1223 uint32 flags_; // Behavioral flags for the arc iterator
1224 mutable Arc arc_; // Memory to temporarily store computed arcs.
1226 mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache.
1227 mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local FST.
1229 mutable const Arc *arcs_; // Array of arcs.
1230 mutable uint32 data_flags_; // Arc value flags valid for data in arcs_.
1231 mutable Arc final_arc_; // Final arc (when required).
1232 mutable uint32 final_flags_; // Arc value flags valid for final_arc_.
// The iterator is neither copy-constructible nor copy-assignable.
1234 ArcIterator(const ArcIterator &) = delete;
1235 ArcIterator &operator=(const ArcIterator &) = delete;
// Matcher for ReplaceFst. It keeps one LocalMatcher per component FST and
// answers queries for the current state through the matcher of the component
// FST named by that state's tuple, adding an implicit loop arc and the final
// (exit) arc where applicable.
1238 template <class Arc, class StateTable, class CacheStore>
1239 class ReplaceFstMatcher : public MatcherBase<Arc> {
1241 using Label = typename Arc::Label;
1242 using StateId = typename Arc::StateId;
1243 using Weight = typename Arc::Weight;
1245 using FST = ReplaceFst<Arc, StateTable, CacheStore>;
// Matcher used on each component FST; a MultiEpsMatcher wrapped around the
// generic Matcher.
1246 using LocalMatcher = MultiEpsMatcher<Matcher<Fst<Arc>>>;
1248 using StateTuple = typename StateTable::StateTuple;
1250 // This makes a copy of the FST.
// loop_ is created with label kNoLabel on one side and 0 on the other; the two
// labels are swapped when matching on the output side.
1251 ReplaceFstMatcher(const ReplaceFst<Arc, StateTable, CacheStore> &fst,
1252 MatchType match_type)
1253 : owned_fst_(fst.Copy()),
1255 impl_(fst_.GetMutableImpl()),
1256 s_(fst::kNoStateId),
1257 match_type_(match_type),
1258 current_loop_(false),
1260 loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1261 if (match_type_ == fst::MATCH_OUTPUT) {
1262 std::swap(loop_.ilabel, loop_.olabel);
1267 // This doesn't copy the FST.
1268 ReplaceFstMatcher(const ReplaceFst<Arc, StateTable, CacheStore> *fst,
1269 MatchType match_type)
1271 impl_(fst_.GetMutableImpl()),
1272 s_(fst::kNoStateId),
1273 match_type_(match_type),
1274 current_loop_(false),
1276 loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1277 if (match_type_ == fst::MATCH_OUTPUT) {
1278 std::swap(loop_.ilabel, loop_.olabel);
1283 // This makes a copy of the FST.
// Copy constructor: the FST is copied with the given safe flag, and the match
// type is taken from the source matcher. The current state starts unset
// (kNoStateId) rather than being copied.
1285 const ReplaceFstMatcher<Arc, StateTable, CacheStore> &matcher,
1287 : owned_fst_(matcher.fst_.Copy(safe)),
1289 impl_(fst_.GetMutableImpl()),
1290 s_(fst::kNoStateId),
1291 match_type_(matcher.match_type_),
1292 current_loop_(false),
1294 loop_(fst::kNoLabel, 0, Weight::One(), fst::kNoStateId) {
1295 if (match_type_ == fst::MATCH_OUTPUT) {
1296 std::swap(loop_.ilabel, loop_.olabel);
1301 // Creates a local matcher for each component FST in the RTN. LocalMatcher is
1302 // a multi-epsilon wrapper matcher. MultiEpsMatcher is used to match each
1303 // non-terminal arc, since these non-terminals
1304 // turn into epsilons on recursion.
1305 void InitMatchers() {
1306 const auto &fst_array = impl_->fst_array_;
1307 matcher_.resize(fst_array.size());
1308 for (Label i = 0; i < fst_array.size(); ++i) {
1311 new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList));
// Registers every non-terminal label as a multi-epsilon label on matcher i.
1312 auto it = impl_->nonterminal_set_.begin();
1313 for (; it != impl_->nonterminal_set_.end(); ++it) {
1314 matcher_[i]->AddMultiEpsLabel(*it);
// Returns a new matcher built with the copy constructor; the caller receives a
// raw pointer.
1320 ReplaceFstMatcher<Arc, StateTable, CacheStore> *Copy(
1321 bool safe = false) const override {
1322 return new ReplaceFstMatcher<Arc, StateTable, CacheStore>(*this, safe);
// Reports the match type based on the FST's label-sorted properties for the
// side being matched; test is forwarded to Properties(). MATCH_NONE is
// returned unchanged, and MATCH_UNKNOWN is returned when neither the sorted
// nor the not-sorted property is set.
1325 MatchType Type(bool test) const override {
1326 if (match_type_ == MATCH_NONE) return match_type_;
1327 const auto true_prop =
1328 match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
1329 const auto false_prop =
1330 match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
1331 const auto props = fst_.Properties(true_prop | false_prop, test);
1332 if (props & true_prop) {
1334 } else if (props & false_prop) {
1337 return MATCH_UNKNOWN;
// Returns the FST being matched.
1341 const Fst<Arc> &GetFst() const override { return fst_; }
// Returns props unchanged.
1343 uint64 Properties(uint64 props) const override { return props; }
1345 // Sets the state from which our matching happens.
// Does nothing when s is already the current state.
1346 void SetState(StateId s) final {
1347 if (s_ == s) return;
1349 tuple_ = impl_->GetStateTable()->Tuple(s_);
1350 if (tuple_.fst_state == kNoStateId) {
1354 // Gets current matcher, used for non-epsilon matching.
1355 current_matcher_ = matcher_[tuple_.fst_id].get();
1356 current_matcher_->SetState(tuple_.fst_state);
// The implicit loop arc points back to the current state.
1357 loop_.nextstate = s_;
1361 // Searches for label from previous set state. If label == 0, first
1362 // hallucinate an epsilon loop; otherwise use the underlying matcher to
1363 // search for the label or epsilons. Note since the ReplaceFst recursion
1364 // on non-terminal arcs causes epsilon transitions to be created we use
1365 // MultiEpsMatcher to search for possible matches of non-terminals. If the
1367 // reaches a final state we also need to add the exiting final arc.
1368 bool Find(Label label) final {
1371 if (label_ == 0 || label_ == kNoLabel) {
1372 // Computes loop directly, avoiding Replace::ComputeArc.
1374 current_loop_ = true;
1377 // Searches for matching multi-epsilons.
// Passing nullptr only asks whether a final arc exists; no arc is filled in.
1378 final_arc_ = impl_->ComputeFinalArc(tuple_, nullptr);
1379 found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1381 // Searches on a sub machine directly using sub machine matcher.
1382 found = current_matcher_->Find(label_);
// Done only when the implicit loop, the final arc and the component matcher
// are all exhausted.
1387 bool Done() const final {
1388 return !current_loop_ && !final_arc_ && current_matcher_->Done();
// Returns the current match: the implicit loop when it is pending, otherwise
// an arc computed into arc_ by ComputeFinalArc() or by ComputeArc() from the
// component matcher's current arc.
1391 const Arc &Value() const final {
1392 if (current_loop_) return loop_;
1394 impl_->ComputeFinalArc(tuple_, &arc_);
1397 const auto &component_arc = current_matcher_->Value();
1398 impl_->ComputeArc(tuple_, component_arc, &arc_);
// Advances to the next match; a pending implicit loop is consumed first.
1403 if (current_loop_) {
1404 current_loop_ = false;
1411 current_matcher_->Next();
// Uses the number of arcs at s in the ReplaceFst as the priority.
1414 ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
// Copy of the FST made by the copying constructors.
1417 std::unique_ptr<const ReplaceFst<Arc, StateTable, CacheStore>> owned_fst_;
// The FST being matched.
1418 const ReplaceFst<Arc, StateTable, CacheStore> &fst_;
// Mutable implementation of fst_.
1419 internal::ReplaceFstImpl<Arc, StateTable, CacheStore> *impl_;
// Matcher of the component FST for the current state; points into matcher_.
1420 LocalMatcher *current_matcher_;
// One local matcher per component FST, indexed by FST id.
1421 std::vector<std::unique_ptr<LocalMatcher>> matcher_;
1422 StateId s_; // Current state.
1423 Label label_; // Current label.
1424 MatchType match_type_; // Supplied by caller.
1426 mutable bool current_loop_; // Current arc is the implicit loop.
1427 mutable bool final_arc_; // Current arc for exiting recursion.
1428 mutable StateTuple tuple_; // Tuple corresponding to s_.
// Copy assignment is disabled.
1432 ReplaceFstMatcher &operator=(const ReplaceFstMatcher &) = delete;
// Out-of-class definition of ReplaceFst::InitStateIterator. It allocates a
// StateIterator specialized for this ReplaceFst.
// NOTE(review): the new iterator is presumably stored in *data - confirm.
1435 template <class Arc, class StateTable, class CacheStore>
1436 inline void ReplaceFst<Arc, StateTable, CacheStore>::InitStateIterator(
1437 StateIteratorData<Arc> *data) const {
1439 new StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(*this);
// Convenience alias for ReplaceFst over StdArc, with the default state table
// and cache store template arguments.
1442 using StdReplaceFst = ReplaceFst<StdArc>;
1444 // Recursively replaces arcs in the root FSTs with other FSTs.
1445 // This version writes the result of replacement to an output MutableFst.
1447 // Replace supports replacement of arcs in one Fst with another FST. This
1448 // replacement is recursive. Replace takes an array of FST(s). One FST
1449 // represents the root (or topology) machine. The root FST refers to other FSTs
1450 // by recursively replacing arcs labeled as non-terminals with the matching
1451 // non-terminal FST. Currently Replace uses the output symbols of the arcs to
1452 // determine whether the arc is a non-terminal arc or not. A non-terminal can be
1453 // any label that is not a non-zero terminal label in the output alphabet.
1455 // Note that input argument is a vector of pairs. These correspond to the tuple
1456 // of non-terminal Label and corresponding FST.
// opts is taken by value so that it can be adjusted locally (see gc_limit
// below) without touching the caller's options object.
1457 template <class Arc>
1458 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1460 MutableFst<Arc> *ofst,
1461 ReplaceFstOptions<Arc> opts = ReplaceFstOptions<Arc>()) {
1463 opts.gc_limit = 0; // Caches only the last state for fastest copy.
// Builds a ReplaceFst over ifst_array and assigns it to *ofst.
1464 *ofst = ReplaceFst<Arc>(ifst_array, opts);
// Overload taking ReplaceUtilOptions: converts them to ReplaceFstOptions<Arc>
// and forwards to the overload that takes ReplaceFstOptions.
1467 template <class Arc>
1468 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1470 MutableFst<Arc> *ofst, const ReplaceUtilOptions &opts) {
1471 Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(opts));
1474 // For backwards compatibility.
// Overload taking the root label and the epsilon_on_replace flag: builds
// ReplaceFstOptions<Arc> from them and forwards to the overload that takes
// ReplaceFstOptions.
1475 template <class Arc>
1476 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1478 MutableFst<Arc> *ofst, typename Arc::Label root,
1479 bool epsilon_on_replace) {
1480 Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root, epsilon_on_replace));
// Overload taking only the root label: builds ReplaceFstOptions<Arc> from root
// and forwards to the overload that takes ReplaceFstOptions.
1483 template <class Arc>
1484 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1486 MutableFst<Arc> *ofst, typename Arc::Label root) {
1487 Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root));
1492 #endif // FST_REPLACE_H_