a4863fd6f42e083faacc0937b585a7cd79898345
[platform/upstream/openfst.git] / src / include / fst / compose.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Class to compute the composition of two FSTs.
5
6 #ifndef FST_COMPOSE_H_
7 #define FST_COMPOSE_H_
8
9 #include <algorithm>
10
11 #include <fst/log.h>
12
13 #include <fst/cache.h>
14 #include <fst/compose-filter.h>
15 #include <fst/fst-decl.h>  // For optional argument declarations
16 #include <fst/lookahead-filter.h>
17 #include <fst/matcher.h>
18 #include <fst/state-table.h>
19 #include <fst/test-properties.h>
20
21
22 namespace fst {
23
24 // Delayed composition options templated on the arc type, the matcher,
25 // the composition filter, and the composition state table.  By
26 // default, the matchers, filter, and state table are constructed by
27 // composition. If set below, the user can instead pass in these
28 // objects; in that case, ComposeFst takes their ownership. This
29 // version controls composition implemented between generic Fst<Arc>
30 // types and a shared matcher type M for Fst<Arc>. This should be
31 // adequate for most applications, giving a reasonable tradeoff
32 // between efficiency and code sharing (but see ComposeFstImplOptions).
33 template <class Arc, class M = Matcher<Fst<Arc>>,
34           class Filter = SequenceComposeFilter<M>,
35           class StateTable =
36               GenericComposeStateTable<Arc, typename Filter::FilterState>>
37 struct ComposeFstOptions : public CacheOptions {
38   M *matcher1;              // FST1 matcher.
39   M *matcher2;              // FST2 matcher.
40   Filter *filter;           // Composition filter.
41   StateTable *state_table;  // Composition state table.
42
43   explicit ComposeFstOptions(const CacheOptions &opts = CacheOptions(),
44                              M *matcher1 = nullptr, M *matcher2 = nullptr,
45                              Filter *filter = nullptr,
46                              StateTable *state_table = nullptr)
47       : CacheOptions(opts),
48         matcher1(matcher1),
49         matcher2(matcher2),
50         filter(filter),
51         state_table(state_table) {}
52 };
53
54 // Forward declaration of ComposeFstMatcher.
55 template <class C, class F, class T>
56 class ComposeFstMatcher;
57
58 // Delayed composition options templated on the two matcher types, the
59 // composition filter, the composition state table and the cache store. By
60 // default, the matchers, filter, state table and cache store are constructed
61 // by composition. If set below, the user can instead pass in these objects; in
62 // that case, ComposeFst takes their ownership. This version controls
63 // composition implemented using arbitrary matchers (of the same arc type but
64 // otherwise arbitrary FST type). The user must ensure the matchers are
65 // compatible. These options permit the most efficient use, but shares the
66 // least code. This is for advanced use only in the most demanding or
67 // specialized applications that can benefit from it; otherwise, prefer
68 // ComposeFstOptions).
69 template <class M1, class M2, class Filter = SequenceComposeFilter<M1, M2>,
70           class StateTable = GenericComposeStateTable<
71               typename M1::Arc, typename Filter::FilterState>,
72           class CacheStore = DefaultCacheStore<typename M1::Arc>>
73 struct ComposeFstImplOptions : public CacheImplOptions<CacheStore> {
74   M1 *matcher1;    // FST1 matcher (see matcher.h)....
75   M2 *matcher2;    // FST2 matcher.
76   Filter *filter;  // Composition filter (see compose-filter.h).
77   StateTable
78     *state_table;        // Composition state table (see compose-state-table.h).
79   bool own_state_table;   // ComposeFstImpl takes ownership of 'state_table'?
80   bool allow_noncommute;  // Allow non-commutative weights
81
82   explicit ComposeFstImplOptions(const CacheOptions &opts,
83                                  M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
84                                  Filter *filter = nullptr,
85                                  StateTable *state_table = nullptr)
86       : CacheImplOptions<CacheStore>(opts),
87         matcher1(matcher1),
88         matcher2(matcher2),
89         filter(filter),
90         state_table(state_table),
91         own_state_table(true),
92         allow_noncommute(false) {}
93
94   explicit ComposeFstImplOptions(const CacheImplOptions<CacheStore> &opts,
95                                  M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
96                                  Filter *filter = nullptr,
97                                  StateTable *state_table = nullptr)
98       : CacheImplOptions<CacheStore>(opts),
99         matcher1(matcher1),
100         matcher2(matcher2),
101         filter(filter),
102         state_table(state_table),
103         own_state_table(true),
104         allow_noncommute(false) {}
105
106   ComposeFstImplOptions()
107       : matcher1(nullptr),
108         matcher2(nullptr),
109         filter(nullptr),
110         state_table(nullptr),
111         own_state_table(true),
112         allow_noncommute(false) {}
113 };
114
115 namespace internal {
116
117 // Implementation of delayed composition. This base class is common to the
118 // variants with different matchers, composition filters and state tables.
119 template <class Arc, class CacheStore = DefaultCacheStore<Arc>,
120           class F = ComposeFst<Arc, CacheStore>>
121 class ComposeFstImplBase
122     : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
123  public:
124   using FST = F;
125   using Label = typename Arc::Label;
126   using StateId = typename Arc::StateId;
127   using Weight = typename Arc::Weight;
128
129   using State = typename CacheStore::State;
130   using CacheImpl = CacheBaseImpl<State, CacheStore>;
131
132   using FstImpl<Arc>::SetType;
133   using FstImpl<Arc>::SetProperties;
134   using FstImpl<Arc>::Properties;
135   using FstImpl<Arc>::SetInputSymbols;
136   using FstImpl<Arc>::SetOutputSymbols;
137
138   using CacheImpl::HasStart;
139   using CacheImpl::HasFinal;
140   using CacheImpl::HasArcs;
141   using CacheImpl::SetFinal;
142   using CacheImpl::SetStart;
143
144   ComposeFstImplBase(const CacheImplOptions<CacheStore> &opts)
145       : CacheImpl(opts) {}
146
147   ComposeFstImplBase(const CacheOptions &opts) : CacheImpl(opts) {}
148
149   ComposeFstImplBase(const ComposeFstImplBase &impl) : CacheImpl(impl, true) {
150     SetType(impl.Type());
151     SetProperties(impl.Properties(), kCopyProperties);
152     SetInputSymbols(impl.InputSymbols());
153     SetOutputSymbols(impl.OutputSymbols());
154   }
155
156   virtual ComposeFstImplBase *Copy() const = 0;
157
158   ~ComposeFstImplBase() override {}
159
160   StateId Start() {
161     if (!HasStart()) {
162       const auto start = ComputeStart();
163       if (start != kNoStateId) SetStart(start);
164     }
165     return CacheImpl::Start();
166   }
167
168   Weight Final(StateId s) {
169     if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
170     return CacheImpl::Final(s);
171   }
172
173   virtual void Expand(StateId s) = 0;
174
175   size_t NumArcs(StateId s) {
176     if (!HasArcs(s)) Expand(s);
177     return CacheImpl::NumArcs(s);
178   }
179
180   size_t NumInputEpsilons(StateId s) {
181     if (!HasArcs(s)) Expand(s);
182     return CacheImpl::NumInputEpsilons(s);
183   }
184
185   size_t NumOutputEpsilons(StateId s) {
186     if (!HasArcs(s)) Expand(s);
187     return CacheImpl::NumOutputEpsilons(s);
188   }
189
190   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
191     if (!HasArcs(s)) Expand(s);
192     CacheImpl::InitArcIterator(s, data);
193   }
194
195   virtual MatcherBase<Arc> *InitMatcher(const F &fst,
196                                         MatchType match_type) const {
197     // Use the default matcher if no override is provided.
198     return nullptr;
199   }
200
201  protected:
202   virtual StateId ComputeStart() = 0;
203   virtual Weight ComputeFinal(StateId s) = 0;
204 };
205
206 // Implementation of delayed composition templated on the matchers (see
207 // matcher.h), composition filter (see compose-filter.h) and the composition
208 // state table (see compose-state-table.h).
209 template <class CacheStore, class Filter, class StateTable>
210 class ComposeFstImpl
211     : public ComposeFstImplBase<typename CacheStore::Arc, CacheStore> {
212  public:
213   using Matcher1 = typename Filter::Matcher1;
214   using Matcher2 = typename Filter::Matcher2;
215
216   using FST1 = typename Matcher1::FST;
217   using FST2 = typename Matcher2::FST;
218
219   using Arc = typename CacheStore::Arc;
220   using Label = typename Arc::Label;
221   using StateId = typename Arc::StateId;
222   using Weight = typename Arc::Weight;
223
224   using FilterState = typename Filter::FilterState;
225   using State = typename CacheStore::State;
226
227   using CacheImpl = CacheBaseImpl<State, CacheStore>;
228
229   using StateTuple = typename StateTable::StateTuple;
230
231   friend class ComposeFstMatcher<CacheStore, Filter, StateTable>;
232
233   using FstImpl<Arc>::SetInputSymbols;
234   using FstImpl<Arc>::SetOutputSymbols;
235   using FstImpl<Arc>::SetType;
236   using FstImpl<Arc>::SetProperties;
237
238   template <class M1, class M2>
239   ComposeFstImpl(const FST1 &fst1, const FST2 &fst2,
240                  const ComposeFstImplOptions<M1, M2, Filter, StateTable,
241                                              CacheStore> &opts);
242
243   ComposeFstImpl(const ComposeFstImpl &impl)
244       : ComposeFstImplBase<Arc, CacheStore>(impl),
245         filter_(new Filter(*impl.filter_, true)),
246         matcher1_(filter_->GetMatcher1()),
247         matcher2_(filter_->GetMatcher2()),
248         fst1_(matcher1_->GetFst()),
249         fst2_(matcher2_->GetFst()),
250         state_table_(new StateTable(*impl.state_table_)),
251         own_state_table_(true),
252         match_type_(impl.match_type_) {}
253
254   ~ComposeFstImpl() override {
255     if (own_state_table_) delete state_table_;
256   }
257
258   ComposeFstImpl *Copy() const override { return new ComposeFstImpl(*this); }
259
260   uint64 Properties() const override { return Properties(kFstProperties); }
261
262   // Sets error if found, and returns other FST impl properties.
263   uint64 Properties(uint64 mask) const override {
264     if ((mask & kError) &&
265         (fst1_.Properties(kError, false) || fst2_.Properties(kError, false) ||
266          (matcher1_->Properties(0) & kError) ||
267          (matcher2_->Properties(0) & kError) |
268              (filter_->Properties(0) & kError) ||
269          state_table_->Error())) {
270       SetProperties(kError, kError);
271     }
272     return FstImpl<Arc>::Properties(mask);
273   }
274
275   // Arranges it so that the first arg to OrderedExpand is the Fst
276   // that will be matched on.
277   void Expand(StateId s) override {
278     const auto &tuple = state_table_->Tuple(s);
279     const auto s1 = tuple.StateId1();
280     const auto s2 = tuple.StateId2();
281     filter_->SetState(s1, s2, tuple.GetFilterState());
282     if (MatchInput(s1, s2)) {
283       OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_, true);
284     } else {
285       OrderedExpand(s, fst1_, s1, fst2_, s2, matcher1_, false);
286     }
287   }
288
289   const FST1 &GetFst1() const { return fst1_; }
290
291   const FST2 &GetFst2() const { return fst2_; }
292
293   const Matcher1 *GetMatcher1() const { return matcher1_; }
294
295   Matcher1 *GetMatcher1() { return matcher1_; }
296
297   const Matcher2 *GetMatcher2() const { return matcher2_; }
298
299   Matcher2 *GetMatcher2() { return matcher2_; }
300
301   const Filter *GetFilter() const { return filter_.get(); }
302
303   Filter *GetFilter() { return filter_.get(); }
304
305   const StateTable *GetStateTable() const { return state_table_; }
306
307   StateTable *GetStateTable() { return state_table_; }
308
309   MatcherBase<Arc> *InitMatcher(const ComposeFst<Arc, CacheStore> &fst,
310                                 MatchType match_type) const override {
311     const auto test_props = match_type == MATCH_INPUT
312                                 ? kFstProperties & ~kILabelInvariantProperties
313                                 : kFstProperties & ~kOLabelInvariantProperties;
314     // If both matchers support 'match_type' and we have a guarantee that a
315     // call to 'filter_->FilterArc(arc1, arc2)' will not modify the ilabel of
316     // arc1 when MATCH_INPUT or the olabel or arc2 when MATCH_OUTPUT, then
317     // ComposeFstMatcher can be used.
318     if ((matcher1_->Type(false) == match_type) &&
319         (matcher2_->Type(false) == match_type) &&
320         (filter_->Properties(test_props) == test_props)) {
321       return new ComposeFstMatcher<CacheStore, Filter, StateTable>(fst, this,
322                                                                    match_type);
323     }
324     return nullptr;
325   }
326
327  private:
328   // This does that actual matching of labels in the composition. The
329   // arguments are ordered so matching is called on state 'sa' of
330   // 'fsta' for each arc leaving state 'sb' of 'fstb'. The 'match_input' arg
331   // determines whether the input or output label of arcs at 'sb' is
332   // the one to match on.
333   template <class FST, class Matcher>
334   void OrderedExpand(StateId s, const Fst<Arc> &, StateId sa, const FST &fstb,
335                      StateId sb, Matcher *matchera, bool match_input) {
336     matchera->SetState(sa);
337     // First processes non-consuming symbols (e.g., epsilons) on FSTA.
338     const Arc loop(match_input ? 0 : kNoLabel, match_input ? kNoLabel : 0,
339                    Weight::One(), sb);
340     MatchArc(s, matchera, loop, match_input);
341     // Then processes matches on FSTB.
342     for (ArcIterator<FST> iterb(fstb, sb); !iterb.Done(); iterb.Next()) {
343       MatchArc(s, matchera, iterb.Value(), match_input);
344     }
345     CacheImpl::SetArcs(s);
346   }
347
348   // Matches a single transition from 'fstb' against 'fata' at 's'.
349   template <class Matcher>
350   void MatchArc(StateId s, Matcher *matchera, const Arc &arc,
351                 bool match_input) {
352     if (matchera->Find(match_input ? arc.olabel : arc.ilabel)) {
353       for (; !matchera->Done(); matchera->Next()) {
354         auto arca = matchera->Value();
355         auto arcb = arc;
356         if (match_input) {
357           const auto &fs = filter_->FilterArc(&arcb, &arca);
358           if (fs != FilterState::NoState()) AddArc(s, arcb, arca, fs);
359         } else {
360           const auto &fs = filter_->FilterArc(&arca, &arcb);
361           if (fs != FilterState::NoState()) AddArc(s, arca, arcb, fs);
362         }
363       }
364     }
365   }
366
367   // Add a matching transition at 's'.
368   void AddArc(StateId s, const Arc &arc1, const Arc &arc2,
369               const FilterState &f) {
370     const StateTuple tuple(arc1.nextstate, arc2.nextstate, f);
371     const Arc oarc(arc1.ilabel, arc2.olabel, Times(arc1.weight, arc2.weight),
372                    state_table_->FindState(tuple));
373     CacheImpl::PushArc(s, oarc);
374   }
375
376   StateId ComputeStart() override {
377     const auto s1 = fst1_.Start();
378     if (s1 == kNoStateId) return kNoStateId;
379     const auto s2 = fst2_.Start();
380     if (s2 == kNoStateId) return kNoStateId;
381     const auto &fs = filter_->Start();
382     const StateTuple tuple(s1, s2, fs);
383     return state_table_->FindState(tuple);
384   }
385
386   Weight ComputeFinal(StateId s) override {
387     const auto &tuple = state_table_->Tuple(s);
388     const auto s1 = tuple.StateId1();
389     auto final1 = matcher1_->Final(s1);
390     if (final1 == Weight::Zero()) return final1;
391     const auto s2 = tuple.StateId2();
392     auto final2 = matcher2_->Final(s2);
393     if (final2 == Weight::Zero()) return final2;
394     filter_->SetState(s1, s2, tuple.GetFilterState());
395     filter_->FilterFinal(&final1, &final2);
396     return Times(final1, final2);
397   }
398
399   // Determines which side to match on per composition state.
400   bool MatchInput(StateId s1, StateId s2) {
401     switch (match_type_) {
402       case MATCH_INPUT:
403         return true;
404       case MATCH_OUTPUT:
405         return false;
406       default:  // MATCH_BOTH
407         const auto priority1 = matcher1_->Priority(s1);
408         const auto priority2 = matcher2_->Priority(s2);
409         if (priority1 == kRequirePriority && priority2 == kRequirePriority) {
410           FSTERROR() << "ComposeFst: Both sides can't require match";
411           SetProperties(kError, kError);
412           return true;
413         }
414         if (priority1 == kRequirePriority) return false;
415         if (priority2 == kRequirePriority) {
416           return true;
417         }
418         return priority1 <= priority2;
419     }
420   }
421
422   // Identifies and verifies the capabilities of the matcher to be used for
423   // composition.
424   void SetMatchType();
425
426   std::unique_ptr<Filter> filter_;
427   Matcher1 *matcher1_;  // Borrowed reference.
428   Matcher2 *matcher2_;  // Borrowed reference.
429   const FST1 &fst1_;
430   const FST2 &fst2_;
431   StateTable *state_table_;
432   bool own_state_table_;
433
434   MatchType match_type_;
435 };
436
437 template <class CacheStore, class Filter, class StateTable>
438 template <class M1, class M2>
439 ComposeFstImpl<CacheStore, Filter, StateTable>::ComposeFstImpl(
440     const FST1 &fst1, const FST2 &fst2,
441     const ComposeFstImplOptions<M1, M2, Filter, StateTable, CacheStore> &opts)
442     : ComposeFstImplBase<Arc, CacheStore>(opts),
443       filter_(opts.filter
444                   ? opts.filter
445                   : new Filter(fst1, fst2, opts.matcher1, opts.matcher2)),
446       matcher1_(filter_->GetMatcher1()),
447       matcher2_(filter_->GetMatcher2()),
448       fst1_(matcher1_->GetFst()),
449       fst2_(matcher2_->GetFst()),
450       state_table_(opts.state_table ? opts.state_table
451                                     : new StateTable(fst1_, fst2_)),
452       own_state_table_(opts.state_table ? opts.own_state_table : true) {
453   SetType("compose");
454   if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) {
455     FSTERROR() << "ComposeFst: Output symbol table of 1st argument "
456                << "does not match input symbol table of 2nd argument";
457     SetProperties(kError, kError);
458   }
459   SetInputSymbols(fst1_.InputSymbols());
460   SetOutputSymbols(fst2_.OutputSymbols());
461   SetMatchType();
462   VLOG(2) << "ComposeFstImpl: Match type: " << match_type_;
463   if (match_type_ == MATCH_NONE) SetProperties(kError, kError);
464   const auto fprops1 = fst1.Properties(kFstProperties, false);
465   const auto fprops2 = fst2.Properties(kFstProperties, false);
466   const auto mprops1 = matcher1_->Properties(fprops1);
467   const auto mprops2 = matcher2_->Properties(fprops2);
468   const auto cprops = ComposeProperties(mprops1, mprops2);
469   SetProperties(filter_->Properties(cprops), kCopyProperties);
470   if (state_table_->Error()) SetProperties(kError, kError);
471 }
472
473 template <class CacheStore, class Filter, class StateTable>
474 void ComposeFstImpl<CacheStore, Filter, StateTable>::SetMatchType() {
475   // Ensures any required matching is possible and known.
476   if ((matcher1_->Flags() & kRequireMatch) &&
477       matcher1_->Type(true) != MATCH_OUTPUT) {
478     FSTERROR() << "ComposeFst: 1st argument cannot perform required matching "
479                << "(sort?).";
480     match_type_ = MATCH_NONE;
481     return;
482   }
483   if ((matcher2_->Flags() & kRequireMatch) &&
484       matcher2_->Type(true) != MATCH_INPUT) {
485     FSTERROR() << "ComposeFst: 2nd argument cannot perform required matching "
486                << "(sort?).";
487     match_type_ = MATCH_NONE;
488     return;
489   }
490   // Finds which sides to match on (favoring minimal testing of capabilities).
491   const auto type1 = matcher1_->Type(false);
492   const auto type2 = matcher2_->Type(false);
493   if (type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) {
494     match_type_ = MATCH_BOTH;
495   } else if (type1 == MATCH_OUTPUT) {
496     match_type_ = MATCH_OUTPUT;
497   } else if (type2 == MATCH_INPUT) {
498     match_type_ = MATCH_INPUT;
499   } else if (matcher1_->Type(true) == MATCH_OUTPUT) {
500     match_type_ = MATCH_OUTPUT;
501   } else if (matcher2_->Type(true) == MATCH_INPUT) {
502     match_type_ = MATCH_INPUT;
503   } else {
504     FSTERROR() << "ComposeFst: 1st argument cannot match on output labels "
505                << "and 2nd argument cannot match on input labels (sort?).";
506     match_type_ = MATCH_NONE;
507   }
508 }
509
510 }  // namespace internal
511
512 // Computes the composition of two transducers. This version is a delayed FST.
513 // If FST1 transduces string x to y with weight a and FST2 transduces y to z
514 // with weight b, then their composition transduces string x to z with weight
515 // Times(x, z).
516 //
517 // The output labels of the first transducer or the input labels of the second
518 // transducer must be sorted (with the default matcher). The weights need to
519 // form a commutative semiring (valid for TropicalWeight and LogWeight).
520 //
521 // Complexity:
522 //
523 // Assuming the first FST is unsorted and the second is sorted,
524 //
525 //   Time: O(v1 v2 d1 (log d2 + m2)),
526 //   Space: O(v1 v2)
527 //
528 // where vi = # of states visited, di = maximum out-degree, and mi the
529 // maximum multiplicity of the states visited, for the ith FST. Constant time
530 // and space to visit an input state or arc is assumed and exclusive of caching.
531 //
532 // Caveats:
533 // - ComposeFst does not trim its output (since it is a delayed operation).
534 // - The efficiency of composition can be strongly affected by several factors:
535 //   - the choice of which transducer is sorted - prefer sorting the FST
536 //     that has the greater average out-degree.
537 //   - the amount of non-determinism
538 //   - the presence and location of epsilon transitions - avoid epsilon
539 //     transitions on the output side of the first transducer or
540 //     the input side of the second transducer or prefer placing
541 //     them later in a path since they delay matching and can
542 //     introduce non-coaccessible states and transitions.
543 //
544 // This class attaches interface to implementation and handles reference
545 // counting, delegating most methods to ImplToFst. The CacheStore specifies the
546 // cache store (default declared in fst-decl.h).
547 template <class A, class CacheStore /* = DefaultCacheStore<A> */>
548 class ComposeFst
549     : public ImplToFst<internal::ComposeFstImplBase<A, CacheStore>> {
550  public:
551   using Arc = A;
552   using StateId = typename Arc::StateId;
553   using Weight = typename Arc::Weight;
554
555   using Store = CacheStore;
556   using State = typename CacheStore::State;
557
558   using Impl = internal::ComposeFstImplBase<A, CacheStore>;
559
560   friend class ArcIterator<ComposeFst<Arc, CacheStore>>;
561   friend class StateIterator<ComposeFst<Arc, CacheStore>>;
562
563   // Compose specifying only caching options.
564   ComposeFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
565              const CacheOptions &opts = CacheOptions())
566       : ImplToFst<Impl>(CreateBase(fst1, fst2, opts)) {}
567
568   // Compose specifying one shared matcher type M. Requires that the input FSTs
569   // and matcher FST types be Fst<Arc>. Recommended for best code-sharing and
570   // matcher compatiblity.
571   template <class Matcher, class Filter, class StateTuple>
572   ComposeFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
573              const ComposeFstOptions<Arc, Matcher, Filter, StateTuple> &opts)
574       : ImplToFst<Impl>(CreateBase1(fst1, fst2, opts)) {}
575
576   // Compose specifying two matcher types Matcher1 and Matcher2. Requires input
577   // FST (of the same Arc type, but o.w. arbitrary) match the corresponding
578   // matcher FST types). Recommended only for advanced use in demanding or
579   // specialized applications due to potential code bloat and matcher
580   // incompatibilities.
581   template <class Matcher1, class Matcher2, class Filter, class StateTuple>
582   ComposeFst(const typename Matcher1::FST &fst1,
583              const typename Matcher2::FST &fst2,
584              const ComposeFstImplOptions<Matcher1, Matcher2, Filter, StateTuple,
585                                          CacheStore> &opts)
586       : ImplToFst<Impl>(CreateBase2(fst1, fst2, opts)) {}
587
588   // See Fst<>::Copy() for doc.
589   ComposeFst(const ComposeFst<A, CacheStore> &fst, bool safe = false)
590       : ImplToFst<Impl>(safe ? std::shared_ptr<Impl>(fst.GetImpl()->Copy())
591                              : fst.GetSharedImpl()) {}
592
593   // Get a copy of this ComposeFst. See Fst<>::Copy() for further doc.
594   ComposeFst<A, CacheStore> *Copy(bool safe = false) const override {
595     return new ComposeFst<A, CacheStore>(*this, safe);
596   }
597
598   inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
599
600   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
601     GetMutableImpl()->InitArcIterator(s, data);
602   }
603
604   MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
605     return GetImpl()->InitMatcher(*this, match_type);
606   }
607
608  protected:
609   using ImplToFst<Impl>::GetImpl;
610   using ImplToFst<Impl>::GetMutableImpl;
611
612   explicit ComposeFst(std::shared_ptr<Impl> impl) : ImplToFst<Impl>(impl) {}
613
614   // Create compose implementation specifying two matcher types.
615   template <class Matcher1, class Matcher2, class Filter, class StateTuple>
616   static std::shared_ptr<Impl> CreateBase2(
617       const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2,
618       const ComposeFstImplOptions<Matcher1, Matcher2, Filter, StateTuple,
619                                   CacheStore> &opts) {
620     auto impl = std::make_shared<
621         internal::ComposeFstImpl<CacheStore, Filter, StateTuple>>(fst1, fst2,
622                                                                   opts);
623     if (!(Weight::Properties() & kCommutative) && !opts.allow_noncommute) {
624       const auto props1 = fst1.Properties(kUnweighted, true);
625       const auto props2 = fst2.Properties(kUnweighted, true);
626       if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) {
627         FSTERROR() << "ComposeFst: Weights must be a commutative semiring: "
628                    << Weight::Type();
629         impl->SetProperties(kError, kError);
630       }
631     }
632     return impl;
633   }
634
635   // Create compose implementation specifying one matcher type; requires that
636   // input and matcher FST types be Fst<Arc>.
637   template <class Matcher, class Filter, class StateTuple>
638   static std::shared_ptr<Impl> CreateBase1(
639       const Fst<Arc> &fst1, const Fst<Arc> &fst2,
640       const ComposeFstOptions<Arc, Matcher, Filter, StateTuple> &opts) {
641     ComposeFstImplOptions<Matcher, Matcher, Filter, StateTuple, CacheStore>
642         nopts(opts, opts.matcher1, opts.matcher2, opts.filter,
643               opts.state_table);
644     return CreateBase2(fst1, fst2, nopts);
645   }
646
647   // Create compose implementation specifying no matcher type.
648   static std::shared_ptr<Impl> CreateBase(const Fst<Arc> &fst1,
649                                           const Fst<Arc> &fst2,
650                                           const CacheOptions &opts) {
651     switch (LookAheadMatchType(fst1, fst2)) {  // Check for lookahead matchers
652       default:
653       case MATCH_NONE: {  // Default composition (no look-ahead).
654         ComposeFstOptions<Arc> nopts(opts);
655         return CreateBase1(fst1, fst2, nopts);
656       }
657       case MATCH_OUTPUT: {  // Lookahead on fst1.
658         using M = typename DefaultLookAhead<Arc, MATCH_OUTPUT>::FstMatcher;
659         using F = typename DefaultLookAhead<Arc, MATCH_OUTPUT>::ComposeFilter;
660         ComposeFstOptions<Arc, M, F> nopts(opts);
661         return CreateBase1(fst1, fst2, nopts);
662       }
663       case MATCH_INPUT: {  // Lookahead on fst2
664         using M = typename DefaultLookAhead<Arc, MATCH_INPUT>::FstMatcher;
665         using F = typename DefaultLookAhead<Arc, MATCH_INPUT>::ComposeFilter;
666         ComposeFstOptions<Arc, M, F> nopts(opts);
667         return CreateBase1(fst1, fst2, nopts);
668       }
669     }
670   }
671
672  private:
673   ComposeFst &operator=(const ComposeFst &fst) = delete;
674 };
675
676 // Specialization for ComposeFst.
677 template <class Arc, class CacheStore>
678 class StateIterator<ComposeFst<Arc, CacheStore>>
679     : public CacheStateIterator<ComposeFst<Arc, CacheStore>> {
680  public:
681   explicit StateIterator(const ComposeFst<Arc, CacheStore> &fst)
682       : CacheStateIterator<ComposeFst<Arc, CacheStore>>(fst,
683                                                         fst.GetMutableImpl()) {}
684 };
685
686 // Specialization for ComposeFst.
687 template <class Arc, class CacheStore>
688 class ArcIterator<ComposeFst<Arc, CacheStore>>
689     : public CacheArcIterator<ComposeFst<Arc, CacheStore>> {
690  public:
691   using StateId = typename Arc::StateId;
692
693   ArcIterator(const ComposeFst<Arc, CacheStore> &fst, StateId s)
694       : CacheArcIterator<ComposeFst<Arc, CacheStore>>(fst.GetMutableImpl(), s) {
695     if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
696   }
697 };
698
699 template <class Arc, class CacheStore>
700 inline void ComposeFst<Arc, CacheStore>::InitStateIterator(
701     StateIteratorData<Arc> *data) const {
702   data->base = new StateIterator<ComposeFst<Arc, CacheStore>>(*this);
703 }
704
705 // Specialized matcher for ComposeFst. Supports MATCH_INPUT or MATCH_OUTPUT,
706 // iff the underlying matchers for the two FSTS being composed support
707 // MATCH_INPUT or MATCH_OUTPUT, respectively.
708 template <class CacheStore, class Filter, class StateTable>
709 class ComposeFstMatcher : public MatcherBase<typename CacheStore::Arc> {
710  public:
711   using Arc = typename CacheStore::Arc;
712   using Label = typename Arc::Label;
713   using StateId = typename Arc::StateId;
714   using Weight = typename Arc::Weight;
715
716   using Matcher1 = typename Filter::Matcher1;
717   using Matcher2 = typename Filter::Matcher2;
718   using FilterState = typename Filter::FilterState;
719
720   using StateTuple = typename StateTable::StateTuple;
721
722   ComposeFstMatcher(
723       const ComposeFst<Arc, CacheStore> &fst,
724       const internal::ComposeFstImpl<CacheStore, Filter, StateTable> *impl,
725       MatchType match_type)
726       : fst_(fst),
727         impl_(impl),
728         s_(kNoStateId),
729         match_type_(match_type),
730         matcher1_(impl->matcher1_->Copy()),
731         matcher2_(impl->matcher2_->Copy()),
732         current_loop_(false),
733         loop_(kNoLabel, 0, Weight::One(), kNoStateId),
734         error_(false) {
735     if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
736   }
737
738   ComposeFstMatcher(
739       const ComposeFstMatcher<CacheStore, Filter, StateTable> &matcher,
740       bool safe = false)
741       : fst_(matcher.fst_),
742         impl_(matcher.impl_),
743         s_(kNoStateId),
744         match_type_(matcher.match_type_),
745         matcher1_(matcher.matcher1_->Copy(safe)),
746         matcher2_(matcher.matcher2_->Copy(safe)),
747         current_loop_(false),
748         loop_(kNoLabel, 0, Weight::One(), kNoStateId),
749         error_(matcher.error_) {
750     if (safe == true) {
751       FSTERROR() << "ComposeFstMatcher: Safe copy not supported";
752       error_ = true;
753     }
754     if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
755   }
756
757   ComposeFstMatcher<CacheStore, Filter, StateTable> *Copy(
758       bool safe = false) const override {
759     return new ComposeFstMatcher<CacheStore, Filter, StateTable>(*this, safe);
760   }
761
762   MatchType Type(bool test) const override {
763     if ((matcher1_->Type(test) == MATCH_NONE) ||
764         (matcher2_->Type(test) == MATCH_NONE)) {
765       return MATCH_NONE;
766     }
767     if (((matcher1_->Type(test) == MATCH_UNKNOWN) &&
768          (matcher2_->Type(test) == MATCH_UNKNOWN)) ||
769         ((matcher1_->Type(test) == MATCH_UNKNOWN) &&
770          (matcher2_->Type(test) == match_type_)) ||
771         ((matcher1_->Type(test) == match_type_) &&
772          (matcher2_->Type(test) == MATCH_UNKNOWN))) {
773       return MATCH_UNKNOWN;
774     }
775     if ((matcher1_->Type(test) == match_type_) &&
776         (matcher2_->Type(test) == match_type_)) {
777       return match_type_;
778     }
779     return MATCH_NONE;
780   }
781
782   const Fst<Arc> &GetFst() const override { return fst_; }
783
784   uint64 Properties(uint64 inprops) const override {
785     auto outprops = inprops;
786     if (error_) outprops |= kError;
787     return outprops;
788   }
789
790   void SetState(StateId s) final {
791     if (s_ == s) return;
792     s_ = s;
793     const auto &tuple = impl_->state_table_->Tuple(s);
794     matcher1_->SetState(tuple.StateId1());
795     matcher2_->SetState(tuple.StateId2());
796     loop_.nextstate = s_;
797   }
798
799   bool Find(Label label) final {
800     bool found = false;
801     current_loop_ = false;
802     if (label == 0) {
803       current_loop_ = true;
804       found = true;
805     }
806     if (match_type_ == MATCH_INPUT) {
807       found = found || FindLabel(label, matcher1_.get(), matcher2_.get());
808     } else {  // match_type_ == MATCH_OUTPUT
809       found = found || FindLabel(label, matcher2_.get(), matcher1_.get());
810     }
811     return found;
812   }
813
814   bool Done() const final {
815     return !current_loop_ && matcher1_->Done() && matcher2_->Done();
816   }
817
818   const Arc &Value() const final { return current_loop_ ? loop_ : arc_; }
819
820   void Next() final {
821     if (current_loop_) {
822       current_loop_ = false;
823     } else if (match_type_ == MATCH_INPUT) {
824       FindNext(matcher1_.get(), matcher2_.get());
825     } else {  // match_type_ == MATCH_OUTPUT
826       FindNext(matcher2_.get(), matcher1_.get());
827     }
828   }
829
830   ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
831
832  private:
833   // Processes a match with the filter and creates resulting arc.
834   bool MatchArc(StateId s, Arc arc1,
835                 Arc arc2) {  // FIXME(kbg): copy but not assignment.
836     const auto &fs = impl_->filter_->FilterArc(&arc1, &arc2);
837     if (fs == FilterState::NoState()) return false;
838     const StateTuple tuple(arc1.nextstate, arc2.nextstate, fs);
839     arc_.ilabel = arc1.ilabel;
840     arc_.olabel = arc2.olabel;
841     arc_.weight = Times(arc1.weight, arc2.weight);
842     arc_.nextstate = impl_->state_table_->FindState(tuple);
843     return true;
844   }
845
846   // Finds the first match allowed by the filter.
847   template <class MatcherA, class MatcherB>
848   bool FindLabel(Label label, MatcherA *matchera, MatcherB *matcherb) {
849     if (matchera->Find(label)) {
850       matcherb->Find(match_type_ == MATCH_INPUT ? matchera->Value().olabel
851                                                 : matchera->Value().ilabel);
852       return FindNext(matchera, matcherb);
853     }
854     return false;
855   }
856
857   // Finds the next match allowed by the filter, returning true iff such a
858   // match is found.
859   template <class MatcherA, class MatcherB>
860   bool FindNext(MatcherA *matchera, MatcherB *matcherb) {
861     // State when entering this function:
862     // 'matchera' is pointed to a match x, y for label x, and a match for y was
863     // requested on 'matcherb'.
864     while (!matchera->Done() || !matcherb->Done()) {
865       if (matcherb->Done()) {
866         // If no more matches for y on 'matcherb', moves forward on 'matchera'
867         // until a match x, y' is found such that there is a match for y' on
868         // 'matcherb'.
869         matchera->Next();
870         while (!matchera->Done() &&
871                !matcherb->Find(match_type_ == MATCH_INPUT
872                                    ? matchera->Value().olabel
873                                    : matchera->Value().ilabel)) {
874           matchera->Next();
875         }
876       }
877       while (!matcherb->Done()) {
878         // 'matchera' is pointing to a match x, y' ('arca') and 'matcherb' is
879         // pointing to a match y', z' ('arcb'). If combining these two arcs is
880         // allowed by the filter (hence resulting in an arc x, z') return true.
881         // Position 'matcherb' on the next potential match for y' before
882         // returning.
883         const auto &arca = matchera->Value();
884         const auto &arcb = matcherb->Value();
885         // Position 'matcherb' on the next potential match for y'.
886         matcherb->Next();
887         // Returns true If combining these two arcs is allowed by the filter
888         // (hence resulting in an arc x, z'); otherwise consider next match
889         // for y' on 'matcherb'.
890         if (MatchArc(s_, match_type_ == MATCH_INPUT ? arca : arcb,
891                      match_type_ == MATCH_INPUT ? arcb : arca)) {
892           return true;
893         }
894       }
895     }
896     // Both 'matchera' and 'matcherb' are done, no more match to analyse.
897     return false;
898   }
899   const ComposeFst<Arc, CacheStore> &fst_;
900   const internal::ComposeFstImpl<CacheStore, Filter, StateTable> *impl_;
901   StateId s_;
902   MatchType match_type_;
903   std::unique_ptr<Matcher1> matcher1_;
904   std::unique_ptr<Matcher2> matcher2_;
905   bool current_loop_;
906   Arc loop_;
907   Arc arc_;
908   bool error_;
909 };
910
911 // Useful alias when using StdArc.
912 using StdComposeFst = ComposeFst<StdArc>;
913
914 enum ComposeFilter {
915   AUTO_FILTER,
916   NULL_FILTER,
917   TRIVIAL_FILTER,
918   SEQUENCE_FILTER,
919   ALT_SEQUENCE_FILTER,
920   MATCH_FILTER
921 };
922
923 struct ComposeOptions {
924   bool connect;               // Connect output?
925   ComposeFilter filter_type;  // Pre-defined filter to use.
926
927   explicit ComposeOptions(bool connect = true,
928                           ComposeFilter filter_type = AUTO_FILTER)
929       : connect(connect), filter_type(filter_type) {}
930 };
931
932 // Computes the composition of two transducers. This version writes
933 // the composed FST into a MutableFst. If FST1 transduces string x to
934 // y with weight a and FST2 transduces y to z with weight b, then
935 // their composition transduces string x to z with weight
936 // Times(x, z).
937 //
938 // The output labels of the first transducer or the input labels of
939 // the second transducer must be sorted.  The weights need to form a
940 // commutative semiring (valid for TropicalWeight and LogWeight).
941 //
942 // Complexity:
943 //
944 // Assuming the first FST is unsorted and the second is sorted:
945 //
946 //   Time: O(V1 V2 D1 (log D2 + M2)),
947 //   Space: O(V1 V2 D1 M2)
948 //
949 // where Vi = # of states, Di = maximum out-degree, and Mi is the maximum
950 // multiplicity, for the ith FST.
951 //
952 // Caveats:
953 //
954 // - Compose trims its output.
955 // - The efficiency of composition can be strongly affected by several factors:
956 //   - the choice of which transducer is sorted - prefer sorting the FST
957 //     that has the greater average out-degree.
958 //   - the amount of non-determinism
959 //   - the presence and location of epsilon transitions - avoid epsilon
960 //     transitions on the output side of the first transducer or
961 //     the input side of the second transducer or prefer placing
962 //     them later in a path since they delay matching and can
963 //     introduce non-coaccessible states and transitions.
964 template <class Arc>
965 void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
966              MutableFst<Arc> *ofst,
967              const ComposeOptions &opts = ComposeOptions()) {
968   using M = Matcher<Fst<Arc>>;
969   // In each case, we cache only the last state for fastest copy.
970   switch (opts.filter_type) {
971     case AUTO_FILTER: {
972       CacheOptions nopts;
973       nopts.gc_limit = 0;
974       *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
975       break;
976     }
977     case NULL_FILTER: {
978       ComposeFstOptions<Arc, M, NullComposeFilter<M>> copts;
979       copts.gc_limit = 0;
980       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
981       break;
982     }
983     case SEQUENCE_FILTER: {
984       ComposeFstOptions<Arc, M, SequenceComposeFilter<M>> copts;
985       copts.gc_limit = 0;
986       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
987       break;
988     }
989     case ALT_SEQUENCE_FILTER: {
990       ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M>> copts;
991       copts.gc_limit = 0;
992       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
993       break;
994     }
995     case MATCH_FILTER: {
996       ComposeFstOptions<Arc, M, MatchComposeFilter<M>> copts;
997       copts.gc_limit = 0;
998       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
999       break;
1000     }
1001     case TRIVIAL_FILTER: {
1002       ComposeFstOptions<Arc, M, TrivialComposeFilter<M>> copts;
1003       copts.gc_limit = 0;
1004       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1005       break;
1006     }
1007   }
1008   if (opts.connect) Connect(ofst);
1009 }
1010
1011 }  // namespace fst
1012
1013 #endif  // FST_COMPOSE_H_