54a65627edc568ecff0e39a2c1b400e4db90ad95
[platform/upstream/openfst.git] / src / include / fst / replace.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Functions and classes for the recursive replacement of FSTs.
5
6 #ifndef FST_REPLACE_H_
7 #define FST_REPLACE_H_
8
9 #include <set>
10 #include <string>
11 #include <unordered_map>
12 #include <utility>
13 #include <vector>
14
15 #include <fst/log.h>
16
17 #include <fst/cache.h>
18 #include <fst/expanded-fst.h>
19 #include <fst/fst-decl.h>  // For optional argument declarations.
20 #include <fst/fst.h>
21 #include <fst/matcher.h>
22 #include <fst/replace-util.h>
23 #include <fst/state-table.h>
24 #include <fst/test-properties.h>
25
26 namespace fst {
27
28 // Replace state tables have the form:
29 //
30 // template <class Arc, class P>
31 // class ReplaceStateTable {
32 //  public:
33 //   using Label = typename Arc::Label Label;
34 //   using StateId = typename Arc::StateId;
35 //
36 //   using PrefixId = P;
37 //   using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
38 //   using StackPrefix = ReplaceStackPrefix<Label, StateId>;
39 //
40 //   // Required constructor.
41 //   ReplaceStateTable(
42 //       const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
43 //       Label root);
44 //
45 //   // Required copy constructor that does not copy state.
46 //   ReplaceStateTable(const ReplaceStateTable<Arc, PrefixId> &table);
47 //
48 //   // Looks up state ID by tuple, adding it if it doesn't exist.
49 //   StateId FindState(const StateTuple &tuple);
50 //
51 //   // Looks up state tuple by ID.
52 //   const StateTuple &Tuple(StateId id) const;
53 //
54 //   // Lookus up prefix ID by stack prefix, adding it if it doesn't exist.
55 //   PrefixId FindPrefixId(const StackPrefix &stack_prefix);
56 //
57 //  // Looks up stack prefix by ID.
58 //  const StackPrefix &GetStackPrefix(PrefixId id) const;
59 // };
60
61 // Tuple that uniquely defines a state in replace.
62 template <class S, class P>
63 struct ReplaceStateTuple {
64   using StateId = S;
65   using PrefixId = P;
66
67   ReplaceStateTuple(PrefixId prefix_id = -1, StateId fst_id = kNoStateId,
68                     StateId fst_state = kNoStateId)
69       : prefix_id(prefix_id), fst_id(fst_id), fst_state(fst_state) {}
70
71   PrefixId prefix_id;  // Index in prefix table.
72   StateId fst_id;      // Current FST being walked.
73   StateId fst_state;   // Current state in FST being walked (not to be
74                        // confused with the thse StateId of the combined FST).
75 };
76
77 // Equality of replace state tuples.
78 template <class StateId, class PrefixId>
79 inline bool operator==(const ReplaceStateTuple<StateId, PrefixId> &x,
80                        const ReplaceStateTuple<StateId, PrefixId> &y) {
81   return x.prefix_id == y.prefix_id && x.fst_id == y.fst_id &&
82          x.fst_state == y.fst_state;
83 }
84
85 // Functor returning true for tuples corresponding to states in the root FST.
86 template <class StateId, class PrefixId>
87 class ReplaceRootSelector {
88  public:
89   bool operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
90     return tuple.prefix_id == 0;
91   }
92 };
93
94 // Functor for fingerprinting replace state tuples.
95 template <class StateId, class PrefixId>
96 class ReplaceFingerprint {
97  public:
98   explicit ReplaceFingerprint(const std::vector<uint64> *size_array)
99       : size_array_(size_array) {}
100
101   uint64 operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
102     return tuple.prefix_id * size_array_->back() +
103            size_array_->at(tuple.fst_id - 1) + tuple.fst_state;
104   }
105
106  private:
107   const std::vector<uint64> *size_array_;
108 };
109
110 // Useful when the fst_state uniquely define the tuple.
111 template <class StateId, class PrefixId>
112 class ReplaceFstStateFingerprint {
113  public:
114   uint64 operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
115     return tuple.fst_state;
116   }
117 };
118
119 // A generic hash function for replace state tuples.
120 template <typename S, typename P>
121 class ReplaceHash {
122  public:
123   size_t operator()(const ReplaceStateTuple<S, P>& t) const {
124     static constexpr auto prime0 = 7853;
125     static constexpr auto prime1 = 7867;
126     return t.prefix_id + t.fst_id * prime0 + t.fst_state * prime1;
127   }
128 };
129
130 // Container for stack prefix.
131 template <class Label, class StateId>
132 class ReplaceStackPrefix {
133  public:
134   struct PrefixTuple {
135     PrefixTuple(Label fst_id = kNoLabel, StateId nextstate = kNoStateId)
136         : fst_id(fst_id), nextstate(nextstate) {}
137
138     Label fst_id;
139     StateId nextstate;
140   };
141
142   ReplaceStackPrefix() {}
143
144   ReplaceStackPrefix(const ReplaceStackPrefix &other)
145       : prefix_(other.prefix_) {}
146
147   void Push(StateId fst_id, StateId nextstate) {
148     prefix_.push_back(PrefixTuple(fst_id, nextstate));
149   }
150
151   void Pop() { prefix_.pop_back(); }
152
153   const PrefixTuple &Top() const { return prefix_[prefix_.size() - 1]; }
154
155   size_t Depth() const { return prefix_.size(); }
156
157  public:
158   std::vector<PrefixTuple> prefix_;
159 };
160
161 // Equality stack prefix classes.
162 template <class Label, class StateId>
163 inline bool operator==(const ReplaceStackPrefix<Label, StateId> &x,
164                        const ReplaceStackPrefix<Label, StateId> &y) {
165   if (x.prefix_.size() != y.prefix_.size()) return false;
166   for (size_t i = 0; i < x.prefix_.size(); ++i) {
167     if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
168         x.prefix_[i].nextstate != y.prefix_[i].nextstate) {
169       return false;
170     }
171   }
172   return true;
173 }
174
175 // Hash function for stack prefix to prefix id.
176 template <class Label, class StateId>
177 class ReplaceStackPrefixHash {
178  public:
179   size_t operator()(const ReplaceStackPrefix<Label, StateId> &prefix) const {
180     size_t sum = 0;
181     for (const auto &pair : prefix.prefix_) {
182       static constexpr auto prime = 7863;
183       sum += pair.fst_id + pair.nextstate * prime;
184     }
185     return sum;
186   }
187 };
188
189 // Replace state tables.
190
191 // A two-level state table for replace. Warning: calls CountStates to compute
192 // the number of states of each component FST.
193 template <class Arc, class P = ssize_t>
194 class VectorHashReplaceStateTable {
195  public:
196   using Label = typename Arc::Label;
197   using StateId = typename Arc::StateId;
198
199   using PrefixId = P;
200
201   using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
202   using StateTable =
203       VectorHashStateTable<ReplaceStateTuple<StateId, PrefixId>,
204                            ReplaceRootSelector<StateId, PrefixId>,
205                            ReplaceFstStateFingerprint<StateId, PrefixId>,
206                            ReplaceFingerprint<StateId, PrefixId>>;
207   using StackPrefix = ReplaceStackPrefix<Label, StateId>;
208   using StackPrefixTable =
209       CompactHashBiTable<PrefixId, StackPrefix,
210                          ReplaceStackPrefixHash<Label, StateId>>;
211
212   VectorHashReplaceStateTable(
213       const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
214       Label root)
215       : root_size_(0) {
216     size_array_.push_back(0);
217     for (const auto &fst_pair : fst_list) {
218       if (fst_pair.first == root) {
219         root_size_ = CountStates(*(fst_pair.second));
220         size_array_.push_back(size_array_.back());
221       } else {
222         size_array_.push_back(size_array_.back() +
223                               CountStates(*(fst_pair.second)));
224       }
225     }
226     state_table_.reset(
227         new StateTable(new ReplaceRootSelector<StateId, PrefixId>,
228                        new ReplaceFstStateFingerprint<StateId, PrefixId>,
229                        new ReplaceFingerprint<StateId, PrefixId>(&size_array_),
230                        root_size_, root_size_ + size_array_.back()));
231   }
232
233   VectorHashReplaceStateTable(
234       const VectorHashReplaceStateTable<Arc, PrefixId> &table)
235       : root_size_(table.root_size_),
236         size_array_(table.size_array_),
237         prefix_table_(table.prefix_table_) {
238     state_table_.reset(
239         new StateTable(new ReplaceRootSelector<StateId, PrefixId>,
240                        new ReplaceFstStateFingerprint<StateId, PrefixId>,
241                        new ReplaceFingerprint<StateId, PrefixId>(&size_array_),
242                        root_size_, root_size_ + size_array_.back()));
243   }
244
245   StateId FindState(const StateTuple &tuple) {
246     return state_table_->FindState(tuple);
247   }
248
249   const StateTuple &Tuple(StateId id) const { return state_table_->Tuple(id); }
250
251   PrefixId FindPrefixId(const StackPrefix &prefix) {
252     return prefix_table_.FindId(prefix);
253   }
254
255   const StackPrefix& GetStackPrefix(PrefixId id) const {
256     return prefix_table_.FindEntry(id);
257   }
258
259  private:
260   StateId root_size_;
261   std::vector<uint64> size_array_;
262   std::unique_ptr<StateTable> state_table_;
263   StackPrefixTable prefix_table_;
264 };
265
266 // Default replace state table.
267 template <class Arc, class P /* = size_t */>
268 class DefaultReplaceStateTable
269     : public CompactHashStateTable<ReplaceStateTuple<typename Arc::StateId, P>,
270                                    ReplaceHash<typename Arc::StateId, P>> {
271  public:
272   using Label = typename Arc::Label;
273   using StateId = typename Arc::StateId;
274
275   using PrefixId = P;
276   using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
277   using StateTable =
278       CompactHashStateTable<StateTuple, ReplaceHash<StateId, PrefixId>>;
279   using StackPrefix = ReplaceStackPrefix<Label, StateId>;
280   using StackPrefixTable =
281       CompactHashBiTable<PrefixId, StackPrefix,
282                          ReplaceStackPrefixHash<Label, StateId>>;
283
284   using StateTable::FindState;
285   using StateTable::Tuple;
286
287   DefaultReplaceStateTable(
288       const std::vector<std::pair<Label, const Fst<Arc> *>> &, Label) {}
289
290   DefaultReplaceStateTable(const DefaultReplaceStateTable<Arc, PrefixId> &table)
291       : StateTable(), prefix_table_(table.prefix_table_) {}
292
293   PrefixId FindPrefixId(const StackPrefix &prefix) {
294     return prefix_table_.FindId(prefix);
295   }
296
297   const StackPrefix &GetStackPrefix(PrefixId id) const {
298     return prefix_table_.FindEntry(id);
299   }
300
301  private:
302   StackPrefixTable prefix_table_;
303 };
304
305 // By default ReplaceFst will copy the input label of the replace arc.
306 // The call_label_type and return_label_type options specify how to manage
307 // the labels of the call arc and the return arc of the replace FST
308 template <class Arc, class StateTable = DefaultReplaceStateTable<Arc>,
309           class CacheStore = DefaultCacheStore<Arc>>
310 struct ReplaceFstOptions : CacheImplOptions<CacheStore> {
311   using Label = typename Arc::Label;
312
313   // Index of root rule for expansion.
314   Label root;
315   // How to label call arc.
316   ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT;
317   // How to label return arc.
318   ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER;
319   // Specifies output label to put on call arc; if kNoLabel, use existing label
320   // on call arc. Otherwise, use this field as the output label.
321   Label call_output_label = kNoLabel;
322   // Specifies label to put on return arc.
323   Label return_label = 0;
324   // Take ownership of input FSTs?
325   bool take_ownership = false;
326   // Pointer to optional pre-constructed state table.
327   StateTable *state_table = nullptr;
328
329   explicit ReplaceFstOptions(const CacheImplOptions<CacheStore> &opts,
330                              Label root = kNoLabel)
331       : CacheImplOptions<CacheStore>(opts), root(root) {}
332
333   explicit ReplaceFstOptions(const CacheOptions &opts, Label root = kNoLabel)
334       : CacheImplOptions<CacheStore>(opts), root(root) {}
335
336   // FIXME(kbg): There are too many constructors here. Come up with a consistent
337   // position for call_output_label (probably the very end) so that it is
338   // possible to express all the remaining constructors with a single
339   // default-argument constructor. Also move clients off of the "backwards
340   // compatibility" constructor, for good.
341
342   explicit ReplaceFstOptions(Label root) : root(root) {}
343
344   explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
345                              ReplaceLabelType return_label_type,
346                              Label return_label)
347       : root(root),
348         call_label_type(call_label_type),
349         return_label_type(return_label_type),
350         return_label(return_label) {}
351
352   explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
353                              ReplaceLabelType return_label_type,
354                              Label call_output_label, Label return_label)
355       : root(root),
356         call_label_type(call_label_type),
357         return_label_type(return_label_type),
358         call_output_label(call_output_label),
359         return_label(return_label) {}
360
361   explicit ReplaceFstOptions(const ReplaceUtilOptions &opts)
362       : ReplaceFstOptions(opts.root, opts.call_label_type,
363                           opts.return_label_type, opts.return_label) {}
364
365   ReplaceFstOptions() : root(kNoLabel) {}
366
367   // For backwards compatibility.
368   ReplaceFstOptions(int64 root, bool epsilon_replace_arc)
369       : root(root),
370         call_label_type(epsilon_replace_arc ? REPLACE_LABEL_NEITHER
371                                             : REPLACE_LABEL_INPUT),
372         call_output_label(epsilon_replace_arc ? 0 : kNoLabel) {}
373 };
374
375 // Forward declaration.
376 template <class Arc, class StateTable, class CacheStore>
377 class ReplaceFstMatcher;
378
379 template <class Arc>
380 using FstList = std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>;
381
382 // Returns true if label type on arc results in epsilon input label.
383 inline bool EpsilonOnInput(ReplaceLabelType label_type) {
384   return label_type == REPLACE_LABEL_NEITHER ||
385          label_type == REPLACE_LABEL_OUTPUT;
386 }
387
388 // Returns true if label type on arc results in epsilon input label.
389 inline bool EpsilonOnOutput(ReplaceLabelType label_type) {
390   return label_type == REPLACE_LABEL_NEITHER ||
391          label_type == REPLACE_LABEL_INPUT;
392 }
393
394 // Returns true if for either the call or return arc ilabel != olabel.
395 template <class Label>
396 bool ReplaceTransducer(ReplaceLabelType call_label_type,
397                        ReplaceLabelType return_label_type,
398                        Label call_output_label) {
399   return call_label_type == REPLACE_LABEL_INPUT ||
400          call_label_type == REPLACE_LABEL_OUTPUT ||
401          (call_label_type == REPLACE_LABEL_BOTH &&
402           call_output_label != kNoLabel) ||
403          return_label_type == REPLACE_LABEL_INPUT ||
404          return_label_type == REPLACE_LABEL_OUTPUT;
405 }
406
407 template <class Arc>
408 uint64 ReplaceFstProperties(typename Arc::Label root_label,
409                             const FstList<Arc> &fst_list,
410                             ReplaceLabelType call_label_type,
411                             ReplaceLabelType return_label_type,
412                             typename Arc::Label call_output_label,
413                             bool *sorted_and_non_empty) {
414   using Label = typename Arc::Label;
415   std::vector<uint64> inprops;
416   bool all_ilabel_sorted = true;
417   bool all_olabel_sorted = true;
418   bool all_non_empty = true;
419   // All nonterminals are negative?
420   bool all_negative = true;
421   // All nonterminals are positive and form a dense range containing 1?
422   bool dense_range = true;
423   Label root_fst_idx = 0;
424   for (Label i = 0; i < fst_list.size(); ++i) {
425     const auto label = fst_list[i].first;
426     if (label >= 0) all_negative = false;
427     if (label > fst_list.size() || label <= 0) dense_range = false;
428     if (label == root_label) root_fst_idx = i;
429     const auto *fst = fst_list[i].second;
430     if (fst->Start() == kNoStateId) all_non_empty = false;
431     if (!fst->Properties(kILabelSorted, false)) all_ilabel_sorted = false;
432     if (!fst->Properties(kOLabelSorted, false)) all_olabel_sorted = false;
433     inprops.push_back(fst->Properties(kCopyProperties, false));
434   }
435   const auto props = ReplaceProperties(
436       inprops, root_fst_idx, EpsilonOnInput(call_label_type),
437       EpsilonOnInput(return_label_type), EpsilonOnOutput(call_label_type),
438       EpsilonOnOutput(return_label_type),
439       ReplaceTransducer(call_label_type, return_label_type, call_output_label),
440       all_non_empty, all_ilabel_sorted, all_olabel_sorted,
441       all_negative || dense_range);
442   const bool sorted = props & (kILabelSorted | kOLabelSorted);
443   *sorted_and_non_empty = all_non_empty && sorted;
444   return props;
445 }
446
447 namespace internal {
448
449 // The replace implementation class supports a dynamic expansion of a recursive
450 // transition network represented as label/FST pairs with dynamic replacable
451 // arcs.
452 template <class Arc, class StateTable, class CacheStore>
453 class ReplaceFstImpl
454     : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
455  public:
456   using Label = typename Arc::Label;
457   using StateId = typename Arc::StateId;
458   using Weight = typename Arc::Weight;
459
460   using State = typename CacheStore::State;
461   using CacheImpl = CacheBaseImpl<State, CacheStore>;
462   using PrefixId = typename StateTable::PrefixId;
463   using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
464   using StackPrefix = ReplaceStackPrefix<Label, StateId>;
465   using NonTerminalHash = std::unordered_map<Label, Label>;
466
467   using FstImpl<Arc>::SetType;
468   using FstImpl<Arc>::SetProperties;
469   using FstImpl<Arc>::WriteHeader;
470   using FstImpl<Arc>::SetInputSymbols;
471   using FstImpl<Arc>::SetOutputSymbols;
472   using FstImpl<Arc>::InputSymbols;
473   using FstImpl<Arc>::OutputSymbols;
474
475   using CacheImpl::PushArc;
476   using CacheImpl::HasArcs;
477   using CacheImpl::HasFinal;
478   using CacheImpl::HasStart;
479   using CacheImpl::SetArcs;
480   using CacheImpl::SetFinal;
481   using CacheImpl::SetStart;
482
483   friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
484
485   ReplaceFstImpl(const FstList<Arc> &fst_list,
486                  const ReplaceFstOptions<Arc, StateTable, CacheStore> &opts)
487       : CacheImpl(opts),
488         call_label_type_(opts.call_label_type),
489         return_label_type_(opts.return_label_type),
490         call_output_label_(opts.call_output_label),
491         return_label_(opts.return_label),
492         state_table_(opts.state_table ? opts.state_table
493                                       : new StateTable(fst_list, opts.root)) {
494     SetType("replace");
495     // If the label is epsilon, then all replace label options are equivalent,
496     // so we set the label types to NEITHER for simplicity.
497     if (call_output_label_ == 0) call_label_type_ = REPLACE_LABEL_NEITHER;
498     if (return_label_ == 0) return_label_type_ = REPLACE_LABEL_NEITHER;
499     if (!fst_list.empty()) {
500       SetInputSymbols(fst_list[0].second->InputSymbols());
501       SetOutputSymbols(fst_list[0].second->OutputSymbols());
502     }
503     fst_array_.push_back(nullptr);
504     for (Label i = 0; i < fst_list.size(); ++i) {
505       const auto label = fst_list[i].first;
506       const auto *fst = fst_list[i].second;
507       nonterminal_hash_[label] = fst_array_.size();
508       nonterminal_set_.insert(label);
509       fst_array_.emplace_back(opts.take_ownership ? fst : fst->Copy());
510       if (i) {
511         if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
512           FSTERROR() << "ReplaceFstImpl: Input symbols of FST " << i
513                      << " do not match input symbols of base FST (0th FST)";
514           SetProperties(kError, kError);
515         }
516         if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
517           FSTERROR() << "ReplaceFstImpl: Output symbols of FST " << i
518                      << " do not match output symbols of base FST (0th FST)";
519           SetProperties(kError, kError);
520         }
521       }
522     }
523     const auto nonterminal = nonterminal_hash_[opts.root];
524     if ((nonterminal == 0) && (fst_array_.size() > 1)) {
525       FSTERROR() << "ReplaceFstImpl: No FST corresponding to root label "
526                  << opts.root << " in the input tuple vector";
527       SetProperties(kError, kError);
528     }
529     root_ = (nonterminal > 0) ? nonterminal : 1;
530     bool all_non_empty_and_sorted = false;
531     SetProperties(ReplaceFstProperties(opts.root, fst_list, call_label_type_,
532                                        return_label_type_, call_output_label_,
533                                        &all_non_empty_and_sorted));
534     // Enables optional caching as long as sorted and all non-empty.
535     always_cache_ = !all_non_empty_and_sorted;
536     VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
537             << (always_cache_ ? "true" : "false");
538   }
539
540   ReplaceFstImpl(const ReplaceFstImpl &impl)
541       : CacheImpl(impl),
542         call_label_type_(impl.call_label_type_),
543         return_label_type_(impl.return_label_type_),
544         call_output_label_(impl.call_output_label_),
545         return_label_(impl.return_label_),
546         always_cache_(impl.always_cache_),
547         state_table_(new StateTable(*(impl.state_table_))),
548         nonterminal_set_(impl.nonterminal_set_),
549         nonterminal_hash_(impl.nonterminal_hash_),
550         root_(impl.root_) {
551     SetType("replace");
552     SetProperties(impl.Properties(), kCopyProperties);
553     SetInputSymbols(impl.InputSymbols());
554     SetOutputSymbols(impl.OutputSymbols());
555     fst_array_.reserve(impl.fst_array_.size());
556     fst_array_.emplace_back(nullptr);
557     for (Label i = 1; i < impl.fst_array_.size(); ++i) {
558       fst_array_.emplace_back(impl.fst_array_[i]->Copy(true));
559     }
560   }
561
562   // Computes the dependency graph of the replace class and returns
563   // true if the dependencies are cyclic. Cyclic dependencies will result
564   // in an un-expandable FST.
565   bool CyclicDependencies() const {
566     const ReplaceUtilOptions opts(root_);
567     ReplaceUtil<Arc> replace_util(fst_array_, nonterminal_hash_, opts);
568     return replace_util.CyclicDependencies();
569   }
570
571   StateId Start() {
572     if (!HasStart()) {
573       if (fst_array_.size() == 1) {
574         SetStart(kNoStateId);
575         return kNoStateId;
576       } else {
577         const auto fst_start = fst_array_[root_]->Start();
578         if (fst_start == kNoStateId) return kNoStateId;
579         const auto prefix = GetPrefixId(StackPrefix());
580         const auto start =
581             state_table_->FindState(StateTuple(prefix, root_, fst_start));
582         SetStart(start);
583         return start;
584       }
585     } else {
586       return CacheImpl::Start();
587     }
588   }
589
590   Weight Final(StateId s) {
591     if (HasFinal(s)) return CacheImpl::Final(s);
592     const auto &tuple = state_table_->Tuple(s);
593     auto weight = Weight::Zero();
594     if (tuple.prefix_id == 0) {
595       const auto fst_state = tuple.fst_state;
596       weight = fst_array_[tuple.fst_id]->Final(fst_state);
597     }
598     if (always_cache_ || HasArcs(s)) SetFinal(s, weight);
599     return weight;
600   }
601
602   size_t NumArcs(StateId s) {
603     if (HasArcs(s)) {
604       return CacheImpl::NumArcs(s);
605     } else if (always_cache_) {  // If always caching, expands and caches state.
606       Expand(s);
607       return CacheImpl::NumArcs(s);
608     } else {  // Otherwise computes the number of arcs without expanding.
609       const auto tuple = state_table_->Tuple(s);
610       if (tuple.fst_state == kNoStateId) return 0;
611       auto num_arcs = fst_array_[tuple.fst_id]->NumArcs(tuple.fst_state);
612       if (ComputeFinalArc(tuple, nullptr)) ++num_arcs;
613       return num_arcs;
614     }
615   }
616
617   // Returns whether a given label is a non-terminal.
618   bool IsNonTerminal(Label label) const {
619     if (label < *nonterminal_set_.begin() ||
620         label > *nonterminal_set_.rbegin()) {
621       return false;
622     } else {
623       return nonterminal_hash_.count(label);
624     }
625     // TODO(allauzen): be smarter and take advantage of all_dense or
626     // all_negative. Also use this in ComputeArc. This would require changes to
627     // Replace so that recursing into an empty FST lead to a non co-accessible
628     // state instead of deleting the arc as done currently. The current use
629     // correct, since labels are sorted if all_non_empty is true.
630   }
631
632   size_t NumInputEpsilons(StateId s) {
633     if (HasArcs(s)) {
634       return CacheImpl::NumInputEpsilons(s);
635     } else if (always_cache_ || !Properties(kILabelSorted)) {
636       // If always caching or if the number of input epsilons is too expensive
637       // to compute without caching (i.e., not ilabel-sorted), then expands and
638       // caches state.
639       Expand(s);
640       return CacheImpl::NumInputEpsilons(s);
641     } else {
642       // Otherwise, computes the number of input epsilons without caching.
643       const auto tuple = state_table_->Tuple(s);
644       if (tuple.fst_state == kNoStateId) return 0;
645       size_t num = 0;
646       if (!EpsilonOnInput(call_label_type_)) {
647         // If EpsilonOnInput(c) is false, all input epsilon arcs
648         // are also input epsilons arcs in the underlying machine.
649         num = fst_array_[tuple.fst_id]->NumInputEpsilons(tuple.fst_state);
650       } else {
651         // Otherwise, one need to consider that all non-terminal arcs
652         // in the underlying machine also become input epsilon arc.
653         ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
654         for (; !aiter.Done() && ((aiter.Value().ilabel == 0) ||
655                                  IsNonTerminal(aiter.Value().olabel));
656              aiter.Next()) {
657           ++num;
658         }
659       }
660       if (EpsilonOnInput(return_label_type_) &&
661           ComputeFinalArc(tuple, nullptr)) {
662         ++num;
663       }
664       return num;
665     }
666   }
667
668   size_t NumOutputEpsilons(StateId s) {
669     if (HasArcs(s)) {
670       return CacheImpl::NumOutputEpsilons(s);
671     } else if (always_cache_ || !Properties(kOLabelSorted)) {
672       // If always caching or if the number of output epsilons is too expensive
673       // to compute without caching (i.e., not olabel-sorted), then expands and
674       // caches state.
675       Expand(s);
676       return CacheImpl::NumOutputEpsilons(s);
677     } else {
678       // Otherwise, computes the number of output epsilons without caching.
679       const auto tuple = state_table_->Tuple(s);
680       if (tuple.fst_state == kNoStateId) return 0;
681       size_t num = 0;
682       if (!EpsilonOnOutput(call_label_type_)) {
683         // If EpsilonOnOutput(c) is false, all output epsilon arcs are also
684         // output epsilons arcs in the underlying machine.
685         num = fst_array_[tuple.fst_id]->NumOutputEpsilons(tuple.fst_state);
686       } else {
687         // Otherwise, one need to consider that all non-terminal arcs in the
688         // underlying machine also become output epsilon arc.
689         ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
690         for (; !aiter.Done() && ((aiter.Value().olabel == 0) ||
691                                  IsNonTerminal(aiter.Value().olabel));
692              aiter.Next()) {
693           ++num;
694         }
695       }
696       if (EpsilonOnOutput(return_label_type_) &&
697           ComputeFinalArc(tuple, nullptr)) {
698         ++num;
699       }
700       return num;
701     }
702   }
703
704   uint64 Properties() const override { return Properties(kFstProperties); }
705
706   // Sets error if found, and returns other FST impl properties.
707   uint64 Properties(uint64 mask) const override {
708     if (mask & kError) {
709       for (Label i = 1; i < fst_array_.size(); ++i) {
710         if (fst_array_[i]->Properties(kError, false)) {
711           SetProperties(kError, kError);
712         }
713       }
714     }
715     return FstImpl<Arc>::Properties(mask);
716   }
717
718   // Returns the base arc iterator, and if arcs have not been computed yet,
719   // extends and recurses for new arcs.
720   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
721     if (!HasArcs(s)) Expand(s);
722     CacheImpl::InitArcIterator(s, data);
723     // TODO(allauzen): Set behaviour of generic iterator.
724     // Warning: ArcIterator<ReplaceFst<A>>::InitCache() relies on current
725     // behaviour.
726   }
727
728   // Extends current state (walk arcs one level deep).
729   void Expand(StateId s) {
730     const auto tuple = state_table_->Tuple(s);
731     if (tuple.fst_state == kNoStateId) {  // Local FST is empty.
732       SetArcs(s);
733       return;
734     }
735     ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
736     Arc arc;
737     // Creates a final arc when needed.
738     if (ComputeFinalArc(tuple, &arc)) PushArc(s, arc);
739     // Expands all arcs leaving the state.
740     for (; !aiter.Done(); aiter.Next()) {
741       if (ComputeArc(tuple, aiter.Value(), &arc)) PushArc(s, arc);
742     }
743     SetArcs(s);
744   }
745
746   void Expand(StateId s, const StateTuple &tuple,
747               const ArcIteratorData<Arc> &data) {
748     if (tuple.fst_state == kNoStateId) {  // Local FST is empty.
749       SetArcs(s);
750       return;
751     }
752     ArcIterator<Fst<Arc>> aiter(data);
753     Arc arc;
754     // Creates a final arc when needed.
755     if (ComputeFinalArc(tuple, &arc)) AddArc(s, arc);
756     // Expands all arcs leaving the state.
757     for (; !aiter.Done(); aiter.Next()) {
758       if (ComputeArc(tuple, aiter.Value(), &arc)) AddArc(s, arc);
759     }
760     SetArcs(s);
761   }
762
763   // If acpp is null, only returns true if a final arcp is required, but does
764   // not actually compute it.
765   bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp,
766                        uint32 flags = kArcValueFlags) {
767     const auto fst_state = tuple.fst_state;
768     if (fst_state == kNoStateId) return false;
769     // If state is final, pops the stack.
770     if (fst_array_[tuple.fst_id]->Final(fst_state) != Weight::Zero() &&
771         tuple.prefix_id) {
772       if (arcp) {
773         arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_;
774         arcp->olabel =
775             (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_;
776         if (flags & kArcNextStateValue) {
777           const auto &stack = state_table_->GetStackPrefix(tuple.prefix_id);
778           const auto prefix_id = PopPrefix(stack);
779           const auto &top = stack.Top();
780           arcp->nextstate = state_table_->FindState(
781               StateTuple(prefix_id, top.fst_id, top.nextstate));
782         }
783         if (flags & kArcWeightValue) {
784           arcp->weight = fst_array_[tuple.fst_id]->Final(fst_state);
785         }
786       }
787       return true;
788     } else {
789       return false;
790     }
791   }
792
793   // Computes an arc in the FST corresponding to one in the underlying machine.
794   // Returns false if the underlying arc corresponds to no arc in the resulting
795   // FST.
796   bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp,
797                   uint32 flags = kArcValueFlags) {
798     if (!EpsilonOnInput(call_label_type_) &&
799         (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
800       *arcp = arc;
801       return true;
802     }
803     if (arc.olabel == 0 || arc.olabel < *nonterminal_set_.begin() ||
804         arc.olabel > *nonterminal_set_.rbegin()) {  // Expands local FST.
805       const auto nextstate =
806           flags & kArcNextStateValue
807               ? state_table_->FindState(
808                     StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
809               : kNoStateId;
810       *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
811     } else {
812       // Checks for non-terminal.
813       const auto it = nonterminal_hash_.find(arc.olabel);
814       if (it != nonterminal_hash_.end()) {  // Recurses into non-terminal.
815         const auto nonterminal = it->second;
816         const auto nt_prefix =
817             PushPrefix(state_table_->GetStackPrefix(tuple.prefix_id),
818                        tuple.fst_id, arc.nextstate);
819         // If the start state is valid, replace; othewise, the arc is implicitly
820         // deleted.
821         const auto nt_start = fst_array_[nonterminal]->Start();
822         if (nt_start != kNoStateId) {
823           const auto nt_nextstate = flags & kArcNextStateValue
824                                         ? state_table_->FindState(StateTuple(
825                                               nt_prefix, nonterminal, nt_start))
826                                         : kNoStateId;
827           const auto ilabel =
828               (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel;
829           const auto olabel =
830               (EpsilonOnOutput(call_label_type_))
831                   ? 0
832                   : ((call_output_label_ == kNoLabel) ? arc.olabel
833                                                       : call_output_label_);
834           *arcp = Arc(ilabel, olabel, arc.weight, nt_nextstate);
835         } else {
836           return false;
837         }
838       } else {
839         const auto nextstate =
840             flags & kArcNextStateValue
841                 ? state_table_->FindState(
842                       StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
843                 : kNoStateId;
844         *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
845       }
846     }
847     return true;
848   }
849
850   // Returns the arc iterator flags supported by this FST.
851   uint32 ArcIteratorFlags() const {
852     uint32 flags = kArcValueFlags;
853     if (!always_cache_) flags |= kArcNoCache;
854     return flags;
855   }
856
857   StateTable *GetStateTable() const { return state_table_.get(); }
858
859   const Fst<Arc> *GetFst(Label fst_id) const {
860     return fst_array_[fst_id].get();
861   }
862
863   Label GetFstId(Label nonterminal) const {
864     const auto it = nonterminal_hash_.find(nonterminal);
865     if (it == nonterminal_hash_.end()) {
866       FSTERROR() << "ReplaceFstImpl::GetFstId: Nonterminal not found: "
867                  << nonterminal;
868     }
869     return it->second;
870   }
871
872   // Returns true if label type on call arc results in epsilon input label.
873   bool EpsilonOnCallInput() { return EpsilonOnInput(call_label_type_); }
874
875  private:
876   // The unique index into stack prefix table.
877   PrefixId GetPrefixId(const StackPrefix &prefix) {
878     return state_table_->FindPrefixId(prefix);
879   }
880
881   // The prefix ID after a stack pop.
882   PrefixId PopPrefix(StackPrefix prefix) {
883     prefix.Pop();
884     return GetPrefixId(prefix);
885   }
886
887   // The prefix ID after a stack push.
888   PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
889     prefix.Push(fst_id, nextstate);
890     return GetPrefixId(prefix);
891   }
892
893   // Runtime options
894   ReplaceLabelType call_label_type_;    // How to label call arc.
895   ReplaceLabelType return_label_type_;  // How to label return arc.
896   int64 call_output_label_;  // Specifies output label to put on call arc
897   int64 return_label_;       // Specifies label to put on return arc.
898   bool always_cache_;        // Disable optional caching of arc iterator?
899
900   // State table.
901   std::unique_ptr<StateTable> state_table_;
902
903   // Replace components.
904   std::set<Label> nonterminal_set_;
905   NonTerminalHash nonterminal_hash_;
906   std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
907   Label root_;
908 };
909
910 }  // namespace internal
911
912 //
913 // ReplaceFst supports dynamic replacement of arcs in one FST with another FST.
914 // This replacement is recursive. ReplaceFst can be used to support a variety of
915 // delayed constructions such as recursive
916 // transition networks, union, or closure. It is constructed with an array of
917 // FST(s). One FST represents the root (or topology) machine. The root FST
918 // refers to other FSTs by recursively replacing arcs labeled as non-terminals
919 // with the matching non-terminal FST. Currently the ReplaceFst uses the output
920 // symbols of the arcs to determine whether the arc is a non-terminal arc or
921 // not. A non-terminal can be any label that is not a non-zero terminal label in
922 // the output alphabet.
923 //
924 // Note that the constructor uses a vector of pairs. These correspond to the
925 // tuple of non-terminal Label and corresponding FST. For example to implement
926 // the closure operation we need 2 FSTs. The first root FST is a single
927 // self-loop arc on the start state.
928 //
929 // The ReplaceFst class supports an optionally caching arc iterator.
930 //
931 // The ReplaceFst needs to be built such that it is known to be ilabel- or
932 // olabel-sorted (see usage below).
933 //
934 // Observe that Matcher<Fst<A>> will use the optionally caching arc iterator
935 // when available (the FST is ilabel-sorted and matching on the input, or the
936 // FST is olabel -orted and matching on the output).  In order to obtain the
937 // most efficient behaviour, it is recommended to set call_label_type to
938 // REPLACE_LABEL_INPUT or REPLACE_LABEL_BOTH and return_label_type to
939 // REPLACE_LABEL_OUTPUT or REPLACE_LABEL_NEITHER. This means that the call arc
940 // does not have epsilon on the input side and the return arc has epsilon on the
941 // input side) and matching on the input side.
942 //
943 // This class attaches interface to implementation and handles reference
944 // counting, delegating most methods to ImplToFst.
945 template <class A, class T /* = DefaultReplaceStateTable<A> */,
946           class CacheStore /* = DefaultCacheStore<A> */>
947 class ReplaceFst
948     : public ImplToFst<internal::ReplaceFstImpl<A, T, CacheStore>> {
949  public:
950   using Arc = A;
951   using Label = typename Arc::Label;
952   using StateId = typename Arc::StateId;
953   using Weight = typename Arc::Weight;
954
955   using StateTable = T;
956   using Store = CacheStore;
957   using State = typename CacheStore::State;
958   using Impl = internal::ReplaceFstImpl<Arc, StateTable, CacheStore>;
959   using CacheImpl = internal::CacheBaseImpl<State, CacheStore>;
960
961   using ImplToFst<Impl>::Properties;
962
963   friend class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
964   friend class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
965   friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
966
967   ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
968              Label root)
969       : ImplToFst<Impl>(std::make_shared<Impl>(
970             fst_array, ReplaceFstOptions<Arc, StateTable, CacheStore>(root))) {}
971
972   ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
973              const ReplaceFstOptions<Arc, StateTable, CacheStore> &opts)
974       : ImplToFst<Impl>(std::make_shared<Impl>(fst_array, opts)) {}
975
976   // See Fst<>::Copy() for doc.
977   ReplaceFst(const ReplaceFst<Arc, StateTable, CacheStore> &fst,
978              bool safe = false)
979       : ImplToFst<Impl>(fst, safe) {}
980
981   // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
982   ReplaceFst<Arc, StateTable, CacheStore> *Copy(
983       bool safe = false) const override {
984     return new ReplaceFst<Arc, StateTable, CacheStore>(*this, safe);
985   }
986
987   inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
988
989   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
990     GetMutableImpl()->InitArcIterator(s, data);
991   }
992
993   MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
994     if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
995         ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
996          (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
997       return new ReplaceFstMatcher<Arc, StateTable, CacheStore>(*this,
998                                                                 match_type);
999     } else {
1000       VLOG(2) << "Not using replace matcher";
1001       return nullptr;
1002     }
1003   }
1004
1005   bool CyclicDependencies() const { return GetImpl()->CyclicDependencies(); }
1006
1007   const StateTable &GetStateTable() const {
1008     return *GetImpl()->GetStateTable();
1009   }
1010
1011   const Fst<Arc> &GetFst(Label nonterminal) const {
1012     return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal));
1013   }
1014
1015  private:
1016   using ImplToFst<Impl>::GetImpl;
1017   using ImplToFst<Impl>::GetMutableImpl;
1018
1019   ReplaceFst &operator=(const ReplaceFst &) = delete;
1020 };
1021
1022 // Specialization for ReplaceFst.
1023 template <class Arc, class StateTable, class CacheStore>
1024 class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>
1025     : public CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1026  public:
1027   explicit StateIterator(const ReplaceFst<Arc, StateTable, CacheStore> &fst)
1028       : CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(
1029             fst, fst.GetMutableImpl()) {}
1030 };
1031
1032 // Specialization for ReplaceFst, implementing optional caching. It is be used
1033 // as follows:
1034 //
1035 //   ReplaceFst<A> replace;
1036 //   ArcIterator<ReplaceFst<A>> aiter(replace, s);
1037 //   // Note: ArcIterator< Fst<A>> is always a caching arc iterator.
1038 //   aiter.SetFlags(kArcNoCache, kArcNoCache);
1039 //   // Uses the arc iterator, no arc will be cached, no state will be expanded.
1040 //   // Arc flags can be used to decide which component of the arc need to be
1041 //   computed.
1042 //   aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1043 //   // Wants the ilabel for this arc.
1044 //   aiter.Value();  // Does not compute the destination state.
1045 //   aiter.Next();
1046 //   aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1047 //   // Wants the ilabel and next state for this arc.
1048 //   aiter.Value();  // Does compute the destination state and inserts it
1049 //                   // in the replace state table.
1050 //   // No additional arcs have been cached at this point.
1051 template <class Arc, class StateTable, class CacheStore>
1052 class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
1053  public:
1054   using StateId = typename Arc::StateId;
1055
1056   using StateTuple = typename StateTable::StateTuple;
1057
1058   ArcIterator(const ReplaceFst<Arc, StateTable, CacheStore> &fst, StateId s)
1059       : fst_(fst),
1060         s_(s),
1061         pos_(0),
1062         offset_(0),
1063         flags_(kArcValueFlags),
1064         arcs_(nullptr),
1065         data_flags_(0),
1066         final_flags_(0) {
1067     cache_data_.ref_count = nullptr;
1068     local_data_.ref_count = nullptr;
1069     // If FST does not support optional caching, forces caching.
1070     if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1071         !(fst_.GetImpl()->HasArcs(s_))) {
1072       fst_.GetMutableImpl()->Expand(s_);
1073     }
1074     // If state is already cached, use cached arcs array.
1075     if (fst_.GetImpl()->HasArcs(s_)) {
1076       (fst_.GetImpl())
1077           ->internal::template CacheBaseImpl<
1078               typename CacheStore::State,
1079               CacheStore>::InitArcIterator(s_, &cache_data_);
1080       num_arcs_ = cache_data_.narcs;
1081       arcs_ = cache_data_.arcs;      // arcs_ is a pointer to the cached arcs.
1082       data_flags_ = kArcValueFlags;  // All the arc member values are valid.
1083     } else {  // Otherwise delay decision until Value() is called.
1084       tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(s_);
1085       if (tuple_.fst_state == kNoStateId) {
1086         num_arcs_ = 0;
1087       } else {
1088         // The decision to cache or not to cache has been defered until Value()
1089         // or
1090         // SetFlags() is called. However, the arc iterator is set up now to be
1091         // ready for non-caching in order to keep the Value() method simple and
1092         // efficient.
1093         const auto *rfst = fst_.GetImpl()->GetFst(tuple_.fst_id);
1094         rfst->InitArcIterator(tuple_.fst_state, &local_data_);
1095         // arcs_ is a pointer to the arcs in the underlying machine.
1096         arcs_ = local_data_.arcs;
1097         // Computes the final arc (but not its destination state) if a final arc
1098         // is required.
1099         bool has_final_arc = fst_.GetMutableImpl()->ComputeFinalArc(
1100             tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue);
1101         // Sets the arc value flags that hold for final_arc_.
1102         final_flags_ = kArcValueFlags & ~kArcNextStateValue;
1103         // Computes the number of arcs.
1104         num_arcs_ = local_data_.narcs;
1105         if (has_final_arc) ++num_arcs_;
1106         // Sets the offset between the underlying arc positions and the
1107         // positions
1108         // in the arc iterator.
1109         offset_ = num_arcs_ - local_data_.narcs;
1110         // Defers the decision to cache or not until Value() or SetFlags() is
1111         // called.
1112         data_flags_ = 0;
1113       }
1114     }
1115   }
1116
1117   ~ArcIterator() {
1118     if (cache_data_.ref_count) --(*cache_data_.ref_count);
1119     if (local_data_.ref_count) --(*local_data_.ref_count);
1120   }
1121
1122   void ExpandAndCache() const  {
1123     // TODO(allauzen): revisit this.
1124     // fst_.GetImpl()->Expand(s_, tuple_, local_data_);
1125     // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(s_,
1126     //                                               &cache_data_);
1127     //
1128     fst_.InitArcIterator(s_, &cache_data_);  // Expand and cache state.
1129     arcs_ = cache_data_.arcs;      // arcs_ is a pointer to the cached arcs.
1130     data_flags_ = kArcValueFlags;  // All the arc member values are valid.
1131     offset_ = 0;                   // No offset.
1132   }
1133
1134   void Init() {
1135     if (flags_ & kArcNoCache) {  // If caching is disabled
1136       // arcs_ is a pointer to the arcs in the underlying machine.
1137       arcs_ = local_data_.arcs;
1138       // Sets the arcs value flags that hold for arcs_.
1139       data_flags_ = kArcWeightValue;
1140       if (!fst_.GetMutableImpl()->EpsilonOnCallInput()) {
1141         data_flags_ |= kArcILabelValue;
1142       }
1143       // Sets the offset between the underlying arc positions and the positions
1144       // in the arc iterator.
1145       offset_ = num_arcs_ - local_data_.narcs;
1146     } else {
1147       ExpandAndCache();
1148     }
1149   }
1150
1151   bool Done() const { return pos_ >= num_arcs_; }
1152
1153   const Arc &Value() const {
1154     // If data_flags_ is 0, non-caching was not requested.
1155     if (!data_flags_) {
1156       // TODO(allauzen): Revisit this.
1157       if (flags_ & kArcNoCache) {
1158         // Should never happen.
1159         FSTERROR() << "ReplaceFst: Inconsistent arc iterator flags";
1160       }
1161       ExpandAndCache();
1162     }
1163     if (pos_ - offset_ >= 0) {  // The requested arc is not the final arc.
1164       const auto &arc = arcs_[pos_ - offset_];
1165       if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
1166         // If the value flags match the recquired value flags then returns the
1167         // arc.
1168         return arc;
1169       } else {
1170         // Otherwise, compute the corresponding arc on-the-fly.
1171         fst_.GetMutableImpl()->ComputeArc(tuple_, arc, &arc_,
1172                                           flags_ & kArcValueFlags);
1173         return arc_;
1174       }
1175     } else {  // The requested arc is the final arc.
1176       if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
1177         // If the arc value flags that hold for the final arc do not match the
1178         // requested value flags, then
1179         // final_arc_ needs to be updated.
1180         fst_.GetMutableImpl()->ComputeFinalArc(tuple_, &final_arc_,
1181                                                flags_ & kArcValueFlags);
1182         final_flags_ = flags_ & kArcValueFlags;
1183       }
1184       return final_arc_;
1185     }
1186   }
1187
1188   void Next() { ++pos_; }
1189
1190   size_t Position() const { return pos_; }
1191
1192   void Reset() { pos_ = 0; }
1193
1194   void Seek(size_t pos) { pos_ = pos; }
1195
1196   uint32 Flags() const { return flags_; }
1197
1198   void SetFlags(uint32 flags, uint32 mask) {
1199     // Updates the flags taking into account what flags are supported
1200     // by the FST.
1201     flags_ &= ~mask;
1202     flags_ |= (flags & fst_.GetImpl()->ArcIteratorFlags());
1203     // If non-caching is not requested (and caching has not already been
1204     // performed), then flush data_flags_ to request caching during the next
1205     // call to Value().
1206     if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
1207       if (!fst_.GetImpl()->HasArcs(s_)) data_flags_ = 0;
1208     }
1209     // If data_flags_ has been flushed but non-caching is requested before
1210     // calling Value(), then set up the iterator for non-caching.
1211     if ((flags & kArcNoCache) && (!data_flags_)) Init();
1212   }
1213
1214  private:
1215   const ReplaceFst<Arc, StateTable, CacheStore> &fst_;  // Reference to the FST.
1216   StateId s_;                                           // State in the FST.
1217   mutable StateTuple tuple_;  // Tuple corresponding to state_.
1218
1219   ssize_t pos_;             // Current position.
1220   mutable ssize_t offset_;  // Offset between position in iterator and in arcs_.
1221   ssize_t num_arcs_;        // Number of arcs at state_.
1222   uint32 flags_;            // Behavorial flags for the arc iterator
1223   mutable Arc arc_;         // Memory to temporarily store computed arcs.
1224
1225   mutable ArcIteratorData<Arc> cache_data_;  // Arc iterator data in cache.
1226   mutable ArcIteratorData<Arc> local_data_;  // Arc iterator data in local FST.
1227
1228   mutable const Arc *arcs_;     // Array of arcs.
1229   mutable uint32 data_flags_;   // Arc value flags valid for data in arcs_.
1230   mutable Arc final_arc_;       // Final arc (when required).
1231   mutable uint32 final_flags_;  // Arc value flags valid for final_arc_.
1232
1233   ArcIterator(const ArcIterator &) = delete;
1234   ArcIterator &operator=(const ArcIterator &) = delete;
1235 };
1236
1237 template <class Arc, class StateTable, class CacheStore>
1238 class ReplaceFstMatcher : public MatcherBase<Arc> {
1239  public:
1240   using Label = typename Arc::Label;
1241   using StateId = typename Arc::StateId;
1242   using Weight = typename Arc::Weight;
1243
1244   using FST = ReplaceFst<Arc, StateTable, CacheStore>;
1245   using LocalMatcher = MultiEpsMatcher<Matcher<Fst<Arc>>>;
1246
1247   using StateTuple = typename StateTable::StateTuple;
1248
1249   ReplaceFstMatcher(const ReplaceFst<Arc, StateTable, CacheStore> &fst,
1250                     MatchType match_type)
1251       : fst_(fst),
1252         impl_(fst_.GetMutableImpl()),
1253         s_(fst::kNoStateId),
1254         match_type_(match_type),
1255         current_loop_(false),
1256         final_arc_(false),
1257         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
1258     if (match_type_ == fst::MATCH_OUTPUT) {
1259       std::swap(loop_.ilabel, loop_.olabel);
1260     }
1261     InitMatchers();
1262   }
1263
1264   ReplaceFstMatcher(
1265       const ReplaceFstMatcher<Arc, StateTable, CacheStore> &matcher,
1266       bool safe = false)
1267       : fst_(matcher.fst_),
1268         impl_(fst_.GetMutableImpl()),
1269         s_(fst::kNoStateId),
1270         match_type_(matcher.match_type_),
1271         current_loop_(false),
1272         final_arc_(false),
1273         loop_(fst::kNoLabel, 0, Weight::One(), fst::kNoStateId) {
1274     if (match_type_ == fst::MATCH_OUTPUT) {
1275       std::swap(loop_.ilabel, loop_.olabel);
1276     }
1277     InitMatchers();
1278   }
1279
1280   // Creates a local matcher for each component FST in the RTN. LocalMatcher is
1281   // a multi-epsilon wrapper matcher. MultiEpsilonMatcher is used to match each
1282   // non-terminal arc, since these non-terminal
1283   // turn into epsilons on recursion.
1284   void InitMatchers() {
1285     const auto &fst_array = impl_->fst_array_;
1286     matcher_.resize(fst_array.size());
1287     for (Label i = 0; i < fst_array.size(); ++i) {
1288       if (fst_array[i]) {
1289         matcher_[i].reset(
1290             new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList));
1291         auto it = impl_->nonterminal_set_.begin();
1292         for (; it != impl_->nonterminal_set_.end(); ++it) {
1293           matcher_[i]->AddMultiEpsLabel(*it);
1294         }
1295       }
1296     }
1297   }
1298
1299   ReplaceFstMatcher<Arc, StateTable, CacheStore> *Copy(
1300       bool safe = false) const override {
1301     return new ReplaceFstMatcher<Arc, StateTable, CacheStore>(*this, safe);
1302   }
1303
1304   MatchType Type(bool test) const override {
1305     if (match_type_ == MATCH_NONE) return match_type_;
1306     const auto true_prop =
1307         match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
1308     const auto false_prop =
1309         match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
1310     const auto props = fst_.Properties(true_prop | false_prop, test);
1311     if (props & true_prop) {
1312       return match_type_;
1313     } else if (props & false_prop) {
1314       return MATCH_NONE;
1315     } else {
1316       return MATCH_UNKNOWN;
1317     }
1318   }
1319
1320   const Fst<Arc> &GetFst() const override { return fst_; }
1321
1322   uint64 Properties(uint64 props) const override { return props; }
1323
1324   // Sets the state from which our matching happens.
1325   void SetState(StateId s) final {
1326     if (s_ == s) return;
1327     s_ = s;
1328     tuple_ = impl_->GetStateTable()->Tuple(s_);
1329     if (tuple_.fst_state == kNoStateId) {
1330       done_ = true;
1331       return;
1332     }
1333     // Gets current matcher, used for non-epsilon matching.
1334     current_matcher_ = matcher_[tuple_.fst_id].get();
1335     current_matcher_->SetState(tuple_.fst_state);
1336     loop_.nextstate = s_;
1337     final_arc_ = false;
1338   }
1339
1340   // Searches for label from previous set state. If label == 0, first
1341   // hallucinate an epsilon loop; otherwise use the underlying matcher to
1342   // search for the label or epsilons. Note since the ReplaceFst recursion
1343   // on non-terminal arcs causes epsilon transitions to be created we use
1344   // MultiEpsilonMatcher to search for possible matches of non-terminals. If the
1345   // component FST
1346   // reaches a final state we also need to add the exiting final arc.
1347   bool Find(Label label) final {
1348     bool found = false;
1349     label_ = label;
1350     if (label_ == 0 || label_ == kNoLabel) {
1351       // Computes loop directly, avoiding Replace::ComputeArc.
1352       if (label_ == 0) {
1353         current_loop_ = true;
1354         found = true;
1355       }
1356       // Searches for matching multi-epsilons.
1357       final_arc_ = impl_->ComputeFinalArc(tuple_, nullptr);
1358       found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1359     } else {
1360       // Searches on a sub machine directly using sub machine matcher.
1361       found = current_matcher_->Find(label_);
1362     }
1363     return found;
1364   }
1365
1366   bool Done() const final {
1367     return !current_loop_ && !final_arc_ && current_matcher_->Done();
1368   }
1369
1370   const Arc &Value() const final {
1371     if (current_loop_) return loop_;
1372     if (final_arc_) {
1373       impl_->ComputeFinalArc(tuple_, &arc_);
1374       return arc_;
1375     }
1376     const auto &component_arc = current_matcher_->Value();
1377     impl_->ComputeArc(tuple_, component_arc, &arc_);
1378     return arc_;
1379   }
1380
1381   void Next() final {
1382     if (current_loop_) {
1383       current_loop_ = false;
1384       return;
1385     }
1386     if (final_arc_) {
1387       final_arc_ = false;
1388       return;
1389     }
1390     current_matcher_->Next();
1391   }
1392
1393   ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
1394
1395  private:
1396   const ReplaceFst<Arc, StateTable, CacheStore> &fst_;
1397   internal::ReplaceFstImpl<Arc, StateTable, CacheStore> *impl_;
1398   LocalMatcher *current_matcher_;
1399   std::vector<std::unique_ptr<LocalMatcher>> matcher_;
1400   StateId s_;             // Current state.
1401   Label label_;           // Current label.
1402   MatchType match_type_;  // Supplied by caller.
1403   mutable bool done_;
1404   mutable bool current_loop_;  // Current arc is the implicit loop.
1405   mutable bool final_arc_;     // Current arc for exiting recursion.
1406   mutable StateTuple tuple_;   // Tuple corresponding to state_.
1407   mutable Arc arc_;
1408   Arc loop_;
1409
1410   ReplaceFstMatcher &operator=(const ReplaceFstMatcher &) = delete;
1411 };
1412
1413 template <class Arc, class StateTable, class CacheStore>
1414 inline void ReplaceFst<Arc, StateTable, CacheStore>::InitStateIterator(
1415     StateIteratorData<Arc> *data) const {
1416   data->base =
1417       new StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(*this);
1418 }
1419
1420 using StdReplaceFst = ReplaceFst<StdArc>;
1421
1422 // Recursively replaces arcs in the root FSTs with other FSTs.
1423 // This version writes the result of replacement to an output MutableFst.
1424 //
1425 // Replace supports replacement of arcs in one Fst with another FST. This
1426 // replacement is recursive. Replace takes an array of FST(s). One FST
1427 // represents the root (or topology) machine. The root FST refers to other FSTs
1428 // by recursively replacing arcs labeled as non-terminals with the matching
1429 // non-terminal FST. Currently Replace uses the output symbols of the arcs to
1430 // determine whether the arc is a non-terminal arc or not. A non-terminal can be
1431 // any label that is not a non-zero terminal label in the output alphabet.
1432 //
1433 // Note that input argument is a vector of pairs. These correspond to the tuple
1434 // of non-terminal Label and corresponding FST.
1435 template <class Arc>
1436 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1437                  &ifst_array,
1438              MutableFst<Arc> *ofst,
1439              ReplaceFstOptions<Arc> opts = ReplaceFstOptions<Arc>()) {
1440   opts.gc = true;
1441   opts.gc_limit = 0;  // Caches only the last state for fastest copy.
1442   *ofst = ReplaceFst<Arc>(ifst_array, opts);
1443 }
1444
1445 template <class Arc>
1446 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1447                  &ifst_array,
1448              MutableFst<Arc> *ofst, const ReplaceUtilOptions &opts) {
1449   Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(opts));
1450 }
1451
1452 // For backwards compatibility.
1453 template <class Arc>
1454 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1455                  &ifst_array,
1456              MutableFst<Arc> *ofst, typename Arc::Label root,
1457              bool epsilon_on_replace) {
1458   Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root, epsilon_on_replace));
1459 }
1460
1461 template <class Arc>
1462 void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
1463                  &ifst_array,
1464              MutableFst<Arc> *ofst, typename Arc::Label root) {
1465   Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root));
1466 }
1467
1468 }  // namespace fst
1469
1470 #endif  // FST_REPLACE_H_