2a3c5d7082665bd529cf329a34bf054dbc87d887
[platform/upstream/openfst.git] / src / test / algo_test.h
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Regression test for various FST algorithms.
5
6 #ifndef FST_TEST_ALGO_TEST_H_
7 #define FST_TEST_ALGO_TEST_H_
8
9 #include <fst/log.h>
10
11 #include <fst/fstlib.h>
12 #include "./rand-fst.h"
13
14 DECLARE_int32(repeat);  // defined in ./algo_test.cc
15
16 namespace fst {
17
18 // Mapper to change input and output label of every transition into
19 // epsilons.
20 template <class A>
21 class EpsMapper {
22  public:
23   EpsMapper() {}
24
25   A operator()(const A &arc) const {
26     return A(0, 0, arc.weight, arc.nextstate);
27   }
28
29   uint64 Properties(uint64 props) const {
30     props &= ~kNotAcceptor;
31     props |= kAcceptor;
32     props &= ~kNoIEpsilons & ~kNoOEpsilons & ~kNoEpsilons;
33     props |= kIEpsilons | kOEpsilons | kEpsilons;
34     props &= ~kNotILabelSorted & ~kNotOLabelSorted;
35     props |= kILabelSorted | kOLabelSorted;
36     return props;
37   }
38
39   MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
40
41   MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
42
43   MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
44 };
45
46 // Generic - no lookahead.
47 template <class Arc>
48 void LookAheadCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
49                       MutableFst<Arc> *ofst) {
50   Compose(ifst1, ifst2, ofst);
51 }
52
53 // Specialized and epsilon olabel acyclic - lookahead.
54 void LookAheadCompose(const Fst<StdArc> &ifst1, const Fst<StdArc> &ifst2,
55                       MutableFst<StdArc> *ofst) {
56   std::vector<StdArc::StateId> order;
57   bool acyclic;
58   TopOrderVisitor<StdArc> visitor(&order, &acyclic);
59   DfsVisit(ifst1, &visitor, OutputEpsilonArcFilter<StdArc>());
60   if (acyclic) {  // no ifst1 output epsilon cycles?
61     StdOLabelLookAheadFst lfst1(ifst1);
62     StdVectorFst lfst2(ifst2);
63     LabelLookAheadRelabeler<StdArc>::Relabel(&lfst2, lfst1, true);
64     Compose(lfst1, lfst2, ofst);
65   } else {
66     Compose(ifst1, ifst2, ofst);
67   }
68 }
69
70 // This class tests a variety of identities and properties that must
71 // hold for various algorithms on weighted FSTs.
72 template <class Arc, class WeightGenerator>
73 class WeightedTester {
74  public:
75   typedef typename Arc::Label Label;
76   typedef typename Arc::StateId StateId;
77   typedef typename Arc::Weight Weight;
78
79   WeightedTester(time_t seed, const Fst<Arc> &zero_fst, const Fst<Arc> &one_fst,
80                  const Fst<Arc> &univ_fst, WeightGenerator *weight_generator)
81       : seed_(seed),
82         zero_fst_(zero_fst),
83         one_fst_(one_fst),
84         univ_fst_(univ_fst),
85         weight_generator_(weight_generator) {}
86
87   void Test(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) {
88     TestRational(T1, T2, T3);
89     TestMap(T1);
90     TestCompose(T1, T2, T3);
91     TestSort(T1);
92     TestOptimize(T1);
93     TestSearch(T1);
94   }
95
96  private:
97   // Tests rational operations with identities
98   void TestRational(const Fst<Arc> &T1, const Fst<Arc> &T2,
99                     const Fst<Arc> &T3) {
100     {
101       VLOG(1) << "Check destructive and delayed union are equivalent.";
102       VectorFst<Arc> U1(T1);
103       Union(&U1, T2);
104       UnionFst<Arc> U2(T1, T2);
105       CHECK(Equiv(U1, U2));
106     }
107
108     {
109       VLOG(1) << "Check destructive and delayed concatenation are equivalent.";
110       VectorFst<Arc> C1(T1);
111       Concat(&C1, T2);
112       ConcatFst<Arc> C2(T1, T2);
113       CHECK(Equiv(C1, C2));
114       VectorFst<Arc> C3(T2);
115       Concat(T1, &C3);
116       CHECK(Equiv(C3, C2));
117     }
118
119     {
120       VLOG(1) << "Check destructive and delayed closure* are equivalent.";
121       VectorFst<Arc> C1(T1);
122       Closure(&C1, CLOSURE_STAR);
123       ClosureFst<Arc> C2(T1, CLOSURE_STAR);
124       CHECK(Equiv(C1, C2));
125     }
126
127     {
128       VLOG(1) << "Check destructive and delayed closure+ are equivalent.";
129       VectorFst<Arc> C1(T1);
130       Closure(&C1, CLOSURE_PLUS);
131       ClosureFst<Arc> C2(T1, CLOSURE_PLUS);
132       CHECK(Equiv(C1, C2));
133     }
134
135     {
136       VLOG(1) << "Check union is associative (destructive).";
137       VectorFst<Arc> U1(T1);
138       Union(&U1, T2);
139       Union(&U1, T3);
140
141       VectorFst<Arc> U3(T2);
142       Union(&U3, T3);
143       VectorFst<Arc> U4(T1);
144       Union(&U4, U3);
145
146       CHECK(Equiv(U1, U4));
147     }
148
149     {
150       VLOG(1) << "Check union is associative (delayed).";
151       UnionFst<Arc> U1(T1, T2);
152       UnionFst<Arc> U2(U1, T3);
153
154       UnionFst<Arc> U3(T2, T3);
155       UnionFst<Arc> U4(T1, U3);
156
157       CHECK(Equiv(U2, U4));
158     }
159
160     {
161       VLOG(1) << "Check union is associative (destructive delayed).";
162       UnionFst<Arc> U1(T1, T2);
163       Union(&U1, T3);
164
165       UnionFst<Arc> U3(T2, T3);
166       UnionFst<Arc> U4(T1, U3);
167
168       CHECK(Equiv(U1, U4));
169     }
170
171     {
172       VLOG(1) << "Check concatenation is associative (destructive).";
173       VectorFst<Arc> C1(T1);
174       Concat(&C1, T2);
175       Concat(&C1, T3);
176
177       VectorFst<Arc> C3(T2);
178       Concat(&C3, T3);
179       VectorFst<Arc> C4(T1);
180       Concat(&C4, C3);
181
182       CHECK(Equiv(C1, C4));
183     }
184
185     {
186       VLOG(1) << "Check concatenation is associative (delayed).";
187       ConcatFst<Arc> C1(T1, T2);
188       ConcatFst<Arc> C2(C1, T3);
189
190       ConcatFst<Arc> C3(T2, T3);
191       ConcatFst<Arc> C4(T1, C3);
192
193       CHECK(Equiv(C2, C4));
194     }
195
196     {
197       VLOG(1) << "Check concatenation is associative (destructive delayed).";
198       ConcatFst<Arc> C1(T1, T2);
199       Concat(&C1, T3);
200
201       ConcatFst<Arc> C3(T2, T3);
202       ConcatFst<Arc> C4(T1, C3);
203
204       CHECK(Equiv(C1, C4));
205     }
206
207     if (Weight::Properties() & kLeftSemiring) {
208       VLOG(1) << "Check concatenation left distributes"
209               << " over union (destructive).";
210
211       VectorFst<Arc> U1(T1);
212       Union(&U1, T2);
213       VectorFst<Arc> C1(T3);
214       Concat(&C1, U1);
215
216       VectorFst<Arc> C2(T3);
217       Concat(&C2, T1);
218       VectorFst<Arc> C3(T3);
219       Concat(&C3, T2);
220       VectorFst<Arc> U2(C2);
221       Union(&U2, C3);
222
223       CHECK(Equiv(C1, U2));
224     }
225
226     if (Weight::Properties() & kRightSemiring) {
227       VLOG(1) << "Check concatenation right distributes"
228               << " over union (destructive).";
229       VectorFst<Arc> U1(T1);
230       Union(&U1, T2);
231       VectorFst<Arc> C1(U1);
232       Concat(&C1, T3);
233
234       VectorFst<Arc> C2(T1);
235       Concat(&C2, T3);
236       VectorFst<Arc> C3(T2);
237       Concat(&C3, T3);
238       VectorFst<Arc> U2(C2);
239       Union(&U2, C3);
240
241       CHECK(Equiv(C1, U2));
242     }
243
244     if (Weight::Properties() & kLeftSemiring) {
245       VLOG(1) << "Check concatenation left distributes over union (delayed).";
246       UnionFst<Arc> U1(T1, T2);
247       ConcatFst<Arc> C1(T3, U1);
248
249       ConcatFst<Arc> C2(T3, T1);
250       ConcatFst<Arc> C3(T3, T2);
251       UnionFst<Arc> U2(C2, C3);
252
253       CHECK(Equiv(C1, U2));
254     }
255
256     if (Weight::Properties() & kRightSemiring) {
257       VLOG(1) << "Check concatenation right distributes over union (delayed).";
258       UnionFst<Arc> U1(T1, T2);
259       ConcatFst<Arc> C1(U1, T3);
260
261       ConcatFst<Arc> C2(T1, T3);
262       ConcatFst<Arc> C3(T2, T3);
263       UnionFst<Arc> U2(C2, C3);
264
265       CHECK(Equiv(C1, U2));
266     }
267
268     if (Weight::Properties() & kLeftSemiring) {
269       VLOG(1) << "Check T T* == T+ (destructive).";
270       VectorFst<Arc> S(T1);
271       Closure(&S, CLOSURE_STAR);
272       VectorFst<Arc> C(T1);
273       Concat(&C, S);
274
275       VectorFst<Arc> P(T1);
276       Closure(&P, CLOSURE_PLUS);
277
278       CHECK(Equiv(C, P));
279     }
280
281     if (Weight::Properties() & kRightSemiring) {
282       VLOG(1) << "Check T* T == T+ (destructive).";
283       VectorFst<Arc> S(T1);
284       Closure(&S, CLOSURE_STAR);
285       VectorFst<Arc> C(S);
286       Concat(&C, T1);
287
288       VectorFst<Arc> P(T1);
289       Closure(&P, CLOSURE_PLUS);
290
291       CHECK(Equiv(C, P));
292     }
293
294     if (Weight::Properties() & kLeftSemiring) {
295       VLOG(1) << "Check T T* == T+ (delayed).";
296       ClosureFst<Arc> S(T1, CLOSURE_STAR);
297       ConcatFst<Arc> C(T1, S);
298
299       ClosureFst<Arc> P(T1, CLOSURE_PLUS);
300
301       CHECK(Equiv(C, P));
302     }
303
304     if (Weight::Properties() & kRightSemiring) {
305       VLOG(1) << "Check T* T == T+ (delayed).";
306       ClosureFst<Arc> S(T1, CLOSURE_STAR);
307       ConcatFst<Arc> C(S, T1);
308
309       ClosureFst<Arc> P(T1, CLOSURE_PLUS);
310
311       CHECK(Equiv(C, P));
312     }
313   }
314
315   // Tests map-based operations.
316   void TestMap(const Fst<Arc> &T) {
317     {
318       VLOG(1) << "Check destructive and delayed projection are equivalent.";
319       VectorFst<Arc> P1(T);
320       Project(&P1, PROJECT_INPUT);
321       ProjectFst<Arc> P2(T, PROJECT_INPUT);
322       CHECK(Equiv(P1, P2));
323     }
324
325     {
326       VLOG(1) << "Check destructive and delayed inversion are equivalent.";
327       VectorFst<Arc> I1(T);
328       Invert(&I1);
329       InvertFst<Arc> I2(T);
330       CHECK(Equiv(I1, I2));
331     }
332
333     {
334       VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (destructive).";
335       VectorFst<Arc> P1(T);
336       VectorFst<Arc> I1(T);
337       Project(&P1, PROJECT_INPUT);
338       Invert(&I1);
339       Project(&I1, PROJECT_OUTPUT);
340       CHECK(Equiv(P1, I1));
341     }
342
343     {
344       VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (destructive).";
345       VectorFst<Arc> P1(T);
346       VectorFst<Arc> I1(T);
347       Project(&P1, PROJECT_OUTPUT);
348       Invert(&I1);
349       Project(&I1, PROJECT_INPUT);
350       CHECK(Equiv(P1, I1));
351     }
352
353     {
354       VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (delayed).";
355       ProjectFst<Arc> P1(T, PROJECT_INPUT);
356       InvertFst<Arc> I1(T);
357       ProjectFst<Arc> P2(I1, PROJECT_OUTPUT);
358       CHECK(Equiv(P1, P2));
359     }
360
361     {
362       VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (delayed).";
363       ProjectFst<Arc> P1(T, PROJECT_OUTPUT);
364       InvertFst<Arc> I1(T);
365       ProjectFst<Arc> P2(I1, PROJECT_INPUT);
366       CHECK(Equiv(P1, P2));
367     }
368
369     {
370       VLOG(1) << "Check destructive relabeling";
371       static const int kNumLabels = 10;
372       // set up relabeling pairs
373       std::vector<Label> labelset(kNumLabels);
374       for (size_t i = 0; i < kNumLabels; ++i) labelset[i] = i;
375       for (size_t i = 0; i < kNumLabels; ++i) {
376         using std::swap;
377         swap(labelset[i], labelset[rand() % kNumLabels]);
378       }
379
380       std::vector<std::pair<Label, Label>> ipairs1(kNumLabels);
381       std::vector<std::pair<Label, Label>> opairs1(kNumLabels);
382       for (size_t i = 0; i < kNumLabels; ++i) {
383         ipairs1[i] = std::make_pair(i, labelset[i]);
384         opairs1[i] = std::make_pair(labelset[i], i);
385       }
386       VectorFst<Arc> R(T);
387       Relabel(&R, ipairs1, opairs1);
388
389       std::vector<std::pair<Label, Label>> ipairs2(kNumLabels);
390       std::vector<std::pair<Label, Label>> opairs2(kNumLabels);
391       for (size_t i = 0; i < kNumLabels; ++i) {
392         ipairs2[i] = std::make_pair(labelset[i], i);
393         opairs2[i] = std::make_pair(i, labelset[i]);
394       }
395       Relabel(&R, ipairs2, opairs2);
396       CHECK(Equiv(R, T));
397
398       VLOG(1) << "Check on-the-fly relabeling";
399       RelabelFst<Arc> Rdelay(T, ipairs1, opairs1);
400
401       RelabelFst<Arc> RRdelay(Rdelay, ipairs2, opairs2);
402       CHECK(Equiv(RRdelay, T));
403     }
404
405     {
406       VLOG(1) << "Check encoding/decoding (destructive).";
407       VectorFst<Arc> D(T);
408       uint32 encode_props = 0;
409       if (rand() % 2) encode_props |= kEncodeLabels;
410       if (rand() % 2) encode_props |= kEncodeWeights;
411       EncodeMapper<Arc> encoder(encode_props, ENCODE);
412       Encode(&D, &encoder);
413       Decode(&D, encoder);
414       CHECK(Equiv(D, T));
415     }
416
417     {
418       VLOG(1) << "Check encoding/decoding (delayed).";
419       uint32 encode_props = 0;
420       if (rand() % 2) encode_props |= kEncodeLabels;
421       if (rand() % 2) encode_props |= kEncodeWeights;
422       EncodeMapper<Arc> encoder(encode_props, ENCODE);
423       EncodeFst<Arc> E(T, &encoder);
424       VectorFst<Arc> Encoded(E);
425       DecodeFst<Arc> D(Encoded, encoder);
426       CHECK(Equiv(D, T));
427     }
428
429     {
430       VLOG(1) << "Check gallic mappers (constructive).";
431       ToGallicMapper<Arc> to_mapper;
432       FromGallicMapper<Arc> from_mapper;
433       VectorFst<GallicArc<Arc>> G;
434       VectorFst<Arc> F;
435       ArcMap(T, &G, to_mapper);
436       ArcMap(G, &F, from_mapper);
437       CHECK(Equiv(T, F));
438     }
439
440     {
441       VLOG(1) << "Check gallic mappers (delayed).";
442       ToGallicMapper<Arc> to_mapper;
443       FromGallicMapper<Arc> from_mapper;
444       ArcMapFst<Arc, GallicArc<Arc>, ToGallicMapper<Arc>> G(T, to_mapper);
445       ArcMapFst<GallicArc<Arc>, Arc, FromGallicMapper<Arc>> F(G, from_mapper);
446       CHECK(Equiv(T, F));
447     }
448   }
449
450   // Tests compose-based operations.
451   void TestCompose(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) {
452     if (!(Weight::Properties() & kCommutative)) return;
453
454     VectorFst<Arc> S1(T1);
455     VectorFst<Arc> S2(T2);
456     VectorFst<Arc> S3(T3);
457
458     ILabelCompare<Arc> icomp;
459     OLabelCompare<Arc> ocomp;
460
461     ArcSort(&S1, ocomp);
462     ArcSort(&S2, ocomp);
463     ArcSort(&S3, icomp);
464
465     {
466       VLOG(1) << "Check composition is associative.";
467       ComposeFst<Arc> C1(S1, S2);
468       ComposeFst<Arc> C2(C1, S3);
469       ComposeFst<Arc> C3(S2, S3);
470       ComposeFst<Arc> C4(S1, C3);
471
472       CHECK(Equiv(C2, C4));
473     }
474
475     {
476       VLOG(1) << "Check composition left distributes over union.";
477       UnionFst<Arc> U1(S2, S3);
478       ComposeFst<Arc> C1(S1, U1);
479
480       ComposeFst<Arc> C2(S1, S2);
481       ComposeFst<Arc> C3(S1, S3);
482       UnionFst<Arc> U2(C2, C3);
483
484       CHECK(Equiv(C1, U2));
485     }
486
487     {
488       VLOG(1) << "Check composition right distributes over union.";
489       UnionFst<Arc> U1(S1, S2);
490       ComposeFst<Arc> C1(U1, S3);
491
492       ComposeFst<Arc> C2(S1, S3);
493       ComposeFst<Arc> C3(S2, S3);
494       UnionFst<Arc> U2(C2, C3);
495
496       CHECK(Equiv(C1, U2));
497     }
498
499     VectorFst<Arc> A1(S1);
500     VectorFst<Arc> A2(S2);
501     VectorFst<Arc> A3(S3);
502     Project(&A1, PROJECT_OUTPUT);
503     Project(&A2, PROJECT_INPUT);
504     Project(&A3, PROJECT_INPUT);
505
506     {
507       VLOG(1) << "Check intersection is commutative.";
508       IntersectFst<Arc> I1(A1, A2);
509       IntersectFst<Arc> I2(A2, A1);
510       CHECK(Equiv(I1, I2));
511     }
512
513     {
514       VLOG(1) << "Check all epsilon filters leads to equivalent results.";
515       typedef Matcher<Fst<Arc>> M;
516       ComposeFst<Arc> C1(S1, S2);
517       ComposeFst<Arc> C2(
518           S1, S2, ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M>>());
519       ComposeFst<Arc> C3(S1, S2,
520                          ComposeFstOptions<Arc, M, MatchComposeFilter<M>>());
521
522       CHECK(Equiv(C1, C2));
523       CHECK(Equiv(C1, C3));
524
525       if ((Weight::Properties() & kIdempotent) ||
526           S1.Properties(kNoOEpsilons, false) ||
527           S2.Properties(kNoIEpsilons, false)) {
528         ComposeFst<Arc> C4(
529             S1, S2, ComposeFstOptions<Arc, M, TrivialComposeFilter<M>>());
530         CHECK(Equiv(C1, C4));
531       }
532
533       if (S1.Properties(kNoOEpsilons, false) &&
534           S2.Properties(kNoIEpsilons, false)) {
535         ComposeFst<Arc> C5(S1, S2,
536                            ComposeFstOptions<Arc, M, NullComposeFilter<M>>());
537         CHECK(Equiv(C1, C5));
538       }
539     }
540
541     {
542       VLOG(1) << "Check look-ahead filters lead to equivalent results.";
543       VectorFst<Arc> C1, C2;
544       Compose(S1, S2, &C1);
545       LookAheadCompose(S1, S2, &C2);
546       CHECK(Equiv(C1, C2));
547     }
548   }
549
550   // Tests sorting operations
551   void TestSort(const Fst<Arc> &T) {
552     ILabelCompare<Arc> icomp;
553     OLabelCompare<Arc> ocomp;
554
555     {
556       VLOG(1) << "Check arc sorted Fst is equivalent to its input.";
557       VectorFst<Arc> S1(T);
558       ArcSort(&S1, icomp);
559       CHECK(Equiv(T, S1));
560     }
561
562     {
563       VLOG(1) << "Check destructive and delayed arcsort are equivalent.";
564       VectorFst<Arc> S1(T);
565       ArcSort(&S1, icomp);
566       ArcSortFst<Arc, ILabelCompare<Arc>> S2(T, icomp);
567       CHECK(Equiv(S1, S2));
568     }
569
570     {
571       VLOG(1) << "Check ilabel sorting vs. olabel sorting with inversions.";
572       VectorFst<Arc> S1(T);
573       VectorFst<Arc> S2(T);
574       ArcSort(&S1, icomp);
575       Invert(&S2);
576       ArcSort(&S2, ocomp);
577       Invert(&S2);
578       CHECK(Equiv(S1, S2));
579     }
580
581     {
582       VLOG(1) << "Check topologically sorted Fst is equivalent to its input.";
583       VectorFst<Arc> S1(T);
584       TopSort(&S1);
585       CHECK(Equiv(T, S1));
586     }
587
588     {
589       VLOG(1) << "Check reverse(reverse(T)) = T";
590       for (int i = 0; i < 2; ++i) {
591         VectorFst<ReverseArc<Arc>> R1;
592         VectorFst<Arc> R2;
593         bool require_superinitial = i == 1;
594         Reverse(T, &R1, require_superinitial);
595         Reverse(R1, &R2, require_superinitial);
596         CHECK(Equiv(T, R2));
597       }
598     }
599   }
600
601   // Tests optimization operations
602   void TestOptimize(const Fst<Arc> &T) {
603     uint64 tprops = T.Properties(kFstProperties, true);
604     uint64 wprops = Weight::Properties();
605
606     VectorFst<Arc> A(T);
607     Project(&A, PROJECT_INPUT);
608     {
609       VLOG(1) << "Check connected FST is equivalent to its input.";
610       VectorFst<Arc> C1(T);
611       Connect(&C1);
612       CHECK(Equiv(T, C1));
613     }
614
615     if ((wprops & kSemiring) == kSemiring &&
616         (tprops & kAcyclic || wprops & kIdempotent)) {
617       VLOG(1) << "Check epsilon-removed FST is equivalent to its input.";
618       VectorFst<Arc> R1(T);
619       RmEpsilon(&R1);
620       CHECK(Equiv(T, R1));
621
622       VLOG(1) << "Check destructive and delayed epsilon removal"
623               << "are equivalent.";
624       RmEpsilonFst<Arc> R2(T);
625       CHECK(Equiv(R1, R2));
626
627       VLOG(1) << "Check an FST with a large proportion"
628               << " of epsilon transitions:";
629       // Maps all transitions of T to epsilon-transitions and append
630       // a non-epsilon transition.
631       VectorFst<Arc> U;
632       ArcMap(T, &U, EpsMapper<Arc>());
633       VectorFst<Arc> V;
634       V.SetStart(V.AddState());
635       Arc arc(1, 1, Weight::One(), V.AddState());
636       V.AddArc(V.Start(), arc);
637       V.SetFinal(arc.nextstate, Weight::One());
638       Concat(&U, V);
639       // Check that epsilon-removal preserves the shortest-distance
640       // from the initial state to the final states.
641       std::vector<Weight> d;
642       ShortestDistance(U, &d, true);
643       Weight w = U.Start() < d.size() ? d[U.Start()] : Weight::Zero();
644       VectorFst<Arc> U1(U);
645       RmEpsilon(&U1);
646       ShortestDistance(U1, &d, true);
647       Weight w1 = U1.Start() < d.size() ? d[U1.Start()] : Weight::Zero();
648       CHECK(ApproxEqual(w, w1, kTestDelta));
649       RmEpsilonFst<Arc> U2(U);
650       ShortestDistance(U2, &d, true);
651       Weight w2 = U2.Start() < d.size() ? d[U2.Start()] : Weight::Zero();
652       CHECK(ApproxEqual(w, w2, kTestDelta));
653     }
654
655     if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) {
656       VLOG(1) << "Check determinized FSA is equivalent to its input.";
657       DeterminizeFst<Arc> D(A);
658       CHECK(Equiv(A, D));
659
660       {
661         VLOG(1) << "Check determinized FST is equivalent to its input.";
662         DeterminizeFstOptions<Arc> opts;
663         opts.type = DETERMINIZE_NONFUNCTIONAL;
664         DeterminizeFst<Arc> DT(T, opts);
665         CHECK(Equiv(T, DT));
666       }
667
668       if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
669         VLOG(1) << "Check pruning in determinization";
670         VectorFst<Arc> P;
671         Weight threshold = (*weight_generator_)();
672         DeterminizeOptions<Arc> opts;
673         opts.weight_threshold = threshold;
674         Determinize(A, &P, opts);
675         CHECK(P.Properties(kIDeterministic, true));
676         CHECK(PruneEquiv(A, P, threshold));
677       }
678
679       if ((wprops & kPath) == kPath) {
680         VLOG(1) << "Check min-determinization";
681
682         // Ensures no input epsilons
683         VectorFst<Arc> R(T);
684         std::vector<std::pair<Label, Label>> ipairs, opairs;
685         ipairs.push_back(std::pair<Label, Label>(0, 1));
686         Relabel(&R, ipairs, opairs);
687
688         VectorFst<Arc> M;
689         DeterminizeOptions<Arc> opts;
690         opts.type = DETERMINIZE_DISAMBIGUATE;
691         Determinize(R, &M, opts);
692         CHECK(M.Properties(kIDeterministic, true));
693         CHECK(MinRelated(M, R));
694       }
695
696       int n;
697       {
698         VLOG(1) << "Check size(min(det(A))) <= size(det(A))"
699                 << " and  min(det(A)) equiv det(A)";
700         VectorFst<Arc> M(D);
701         n = M.NumStates();
702         Minimize(&M);
703         CHECK(Equiv(D, M));
704         CHECK(M.NumStates() <= n);
705         n = M.NumStates();
706       }
707
708       if (n && (wprops & kIdempotent) == kIdempotent &&
709           A.Properties(kNoEpsilons, true)) {
710         VLOG(1) << "Check that Revuz's algorithm leads to the"
711                 << " same number of states as Brozozowski's algorithm";
712
713         // Skip test if A is the empty machine or contains epsilons or
714         // if the semiring is not idempotent (to avoid floating point
715         // errors)
716         VectorFst<Arc> R;
717         Reverse(A, &R);
718         RmEpsilon(&R);
719         DeterminizeFst<Arc> DR(R);
720         VectorFst<Arc> RD;
721         Reverse(DR, &RD);
722         DeterminizeFst<Arc> DRD(RD);
723         VectorFst<Arc> M(DRD);
724         CHECK_EQ(n + 1, M.NumStates());  // Accounts for the epsilon transition
725                                          // to the initial state
726       }
727     }
728
729     if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) {
730       VLOG(1) << "Check disambiguated FSA is equivalent to its input.";
731       VectorFst<Arc> R(A), D;
732       RmEpsilon(&R);
733       Disambiguate(R, &D);
734       CHECK(Equiv(R, D));
735       VLOG(1) << "Check disambiguated FSA is unambiguous";
736       CHECK(Unambiguous(D));
737
738       /* TODO(riley): find out why this fails
739       if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
740         VLOG(1)  << "Check pruning in disambiguation";
741         VectorFst<Arc> P;
742         Weight threshold = (*weight_generator_)();
743         DisambiguateOptions<Arc> opts;
744         opts.weight_threshold = threshold;
745         Disambiguate(R, &P, opts);
746         CHECK(Unambiguous(P));
747         CHECK(PruneEquiv(A, P, threshold));
748       }
749       */
750     }
751
752     if (Arc::Type() == LogArc::Type() || Arc::Type() == StdArc::Type()) {
753       VLOG(1) << "Check reweight(T) equiv T";
754       std::vector<Weight> potential;
755       VectorFst<Arc> RI(T);
756       VectorFst<Arc> RF(T);
757       while (potential.size() < RI.NumStates())
758         potential.push_back((*weight_generator_)());
759
760       Reweight(&RI, potential, REWEIGHT_TO_INITIAL);
761       CHECK(Equiv(T, RI));
762
763       Reweight(&RF, potential, REWEIGHT_TO_FINAL);
764       CHECK(Equiv(T, RF));
765     }
766
767     if ((wprops & kIdempotent) || (tprops & kAcyclic)) {
768       VLOG(1) << "Check pushed FST is equivalent to input FST.";
769       // Pushing towards the final state.
770       if (wprops & kRightSemiring) {
771         VectorFst<Arc> P1;
772         Push<Arc, REWEIGHT_TO_FINAL>(T, &P1, kPushLabels);
773         CHECK(Equiv(T, P1));
774
775         VectorFst<Arc> P2;
776         Push<Arc, REWEIGHT_TO_FINAL>(T, &P2, kPushWeights);
777         CHECK(Equiv(T, P2));
778
779         VectorFst<Arc> P3;
780         Push<Arc, REWEIGHT_TO_FINAL>(T, &P3, kPushLabels | kPushWeights);
781         CHECK(Equiv(T, P3));
782       }
783
784       // Pushing towards the initial state.
785       if (wprops & kLeftSemiring) {
786         VectorFst<Arc> P1;
787         Push<Arc, REWEIGHT_TO_INITIAL>(T, &P1, kPushLabels);
788         CHECK(Equiv(T, P1));
789
790         VectorFst<Arc> P2;
791         Push<Arc, REWEIGHT_TO_INITIAL>(T, &P2, kPushWeights);
792         CHECK(Equiv(T, P2));
793         VectorFst<Arc> P3;
794         Push<Arc, REWEIGHT_TO_INITIAL>(T, &P3, kPushLabels | kPushWeights);
795         CHECK(Equiv(T, P3));
796       }
797     }
798
799     if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
800       VLOG(1) << "Check pruning algorithm";
801       {
802         VLOG(1) << "Check equiv. of constructive and destructive algorithms";
803         Weight thresold = (*weight_generator_)();
804         VectorFst<Arc> P1(T);
805         Prune(&P1, thresold);
806         VectorFst<Arc> P2;
807         Prune(T, &P2, thresold);
808         CHECK(Equiv(P1, P2));
809       }
810
811       {
812         VLOG(1) << "Check prune(reverse) equiv reverse(prune)";
813         Weight thresold = (*weight_generator_)();
814         VectorFst<ReverseArc<Arc>> R;
815         VectorFst<Arc> P1(T);
816         VectorFst<Arc> P2;
817         Prune(&P1, thresold);
818         Reverse(T, &R);
819         Prune(&R, thresold.Reverse());
820         Reverse(R, &P2);
821         CHECK(Equiv(P1, P2));
822       }
823       {
824         VLOG(1) << "Check: ShortestDistance(A - prune(A))"
825                 << " > ShortestDistance(A) times Threshold";
826         Weight threshold = (*weight_generator_)();
827         VectorFst<Arc> P;
828         Prune(A, &P, threshold);
829         CHECK(PruneEquiv(A, P, threshold));
830       }
831     }
832     if (tprops & kAcyclic) {
833       VLOG(1) << "Check synchronize(T) equiv T";
834       SynchronizeFst<Arc> S(T);
835       CHECK(Equiv(T, S));
836     }
837   }
838
839   // Tests search operations
840   void TestSearch(const Fst<Arc> &T) {
841     uint64 wprops = Weight::Properties();
842
843     VectorFst<Arc> A(T);
844     Project(&A, PROJECT_INPUT);
845
846     if ((wprops & (kPath | kRightSemiring)) == (kPath | kRightSemiring)) {
847       VLOG(1) << "Check 1-best weight.";
848       VectorFst<Arc> path;
849       ShortestPath(T, &path);
850       Weight tsum = ShortestDistance(T);
851       Weight psum = ShortestDistance(path);
852       CHECK(ApproxEqual(tsum, psum, kTestDelta));
853     }
854
855     if ((wprops & (kPath | kSemiring)) == (kPath | kSemiring)) {
856       VLOG(1) << "Check n-best weights";
857       VectorFst<Arc> R(A);
858       RmEpsilon(&R);
859       int nshortest = rand() % kNumRandomShortestPaths + 2;
860       VectorFst<Arc> paths;
861       ShortestPath(R, &paths, nshortest, true, false, Weight::Zero(),
862                    kNumShortestStates);
863       std::vector<Weight> distance;
864       ShortestDistance(paths, &distance, true);
865       StateId pstart = paths.Start();
866       if (pstart != kNoStateId) {
867         ArcIterator<Fst<Arc>> piter(paths, pstart);
868         for (; !piter.Done(); piter.Next()) {
869           StateId s = piter.Value().nextstate;
870           Weight nsum = s < distance.size()
871                             ? Times(piter.Value().weight, distance[s])
872                             : Weight::Zero();
873           VectorFst<Arc> path;
874           ShortestPath(R, &path);
875           Weight dsum = ShortestDistance(path);
876           CHECK(ApproxEqual(nsum, dsum, kTestDelta));
877           ArcMap(&path, RmWeightMapper<Arc>());
878           VectorFst<Arc> S;
879           Difference(R, path, &S);
880           R = S;
881         }
882       }
883     }
884   }
885
886   // Tests if two FSTS are equivalent by checking if random
887   // strings from one FST are transduced the same by both FSTs.
888   template <class A>
889   bool Equiv(const Fst<A> &fst1, const Fst<A> &fst2) {
890     VLOG(1) << "Check FSTs for sanity (including property bits).";
891     CHECK(Verify(fst1));
892     CHECK(Verify(fst2));
893
894     // Ensures seed used once per instantiation.
895     static UniformArcSelector<A> uniform_selector(seed_);
896     RandGenOptions<UniformArcSelector<A>> opts(uniform_selector,
897                                                kRandomPathLength);
898     return RandEquivalent(fst1, fst2, kNumRandomPaths, kTestDelta, opts);
899   }
900
901   // Tests FSA is unambiguous
902   bool Unambiguous(const Fst<Arc> &fst) {
903     VectorFst<StdArc> sfst, dfst;
904     VectorFst<LogArc> lfst1, lfst2;
905     Map(fst, &sfst, RmWeightMapper<Arc, StdArc>());
906     Determinize(sfst, &dfst);
907     Map(fst, &lfst1, RmWeightMapper<Arc, LogArc>());
908     Map(dfst, &lfst2, RmWeightMapper<StdArc, LogArc>());
909     return Equiv(lfst1, lfst2);
910   }
911
912   // Ensures input-epsilon free transducers fst1 and fst2 have the
913   // same domain and that for each string pair '(is, os)' in fst1,
914   // '(is, os)' is the minimum weight match to 'is' in fst2.
915   template <class A>
916   bool MinRelated(const Fst<A> &fst1, const Fst<A> &fst2) {
917     // Same domain
918     VectorFst<Arc> P1(fst1), P2(fst2);
919     Project(&P1, PROJECT_INPUT);
920     Project(&P2, PROJECT_INPUT);
921     if (!Equiv(P1, P2)) {
922       LOG(ERROR) << "Inputs not equivalent";
923       return false;
924     }
925
926     // Ensures seed used once per instantiation.
927     static UniformArcSelector<A> uniform_selector(seed_);
928     RandGenOptions<UniformArcSelector<A>> opts(uniform_selector,
929                                                kRandomPathLength);
930
931     VectorFst<Arc> path, paths1, paths2;
932     for (ssize_t n = 0; n < kNumRandomPaths; ++n) {
933       RandGen(fst1, &path, opts);
934       Invert(&path);
935       Map(&path, RmWeightMapper<Arc>());
936       Compose(path, fst2, &paths1);
937       Weight sum1 = ShortestDistance(paths1);
938       Compose(paths1, path, &paths2);
939       Weight sum2 = ShortestDistance(paths2);
940       if (!ApproxEqual(Plus(sum1, sum2), sum2, kTestDelta)) {
941         LOG(ERROR) << "Sums not equivalent: " << sum1 << " " << sum2;
942         return false;
943       }
944     }
945     return true;
946   }
947
948   // Tests ShortestDistance(A - P) >=
949   // ShortestDistance(A) times Threshold.
950   template <class A>
951   bool PruneEquiv(const Fst<A> &fst, const Fst<A> &pfst, Weight threshold) {
952     VLOG(1) << "Check FSTs for sanity (including property bits).";
953     CHECK(Verify(fst));
954     CHECK(Verify(pfst));
955
956     DifferenceFst<Arc> D(fst, DeterminizeFst<Arc>(RmEpsilonFst<Arc>(
957                                   ArcMapFst<Arc, Arc, RmWeightMapper<Arc>>(
958                                       pfst, RmWeightMapper<Arc>()))));
959     Weight sum1 = Times(ShortestDistance(fst), threshold);
960     Weight sum2 = ShortestDistance(D);
961     return ApproxEqual(Plus(sum1, sum2), sum1, kTestDelta);
962   }
963
964   // Random seed.
965   int seed_;
966   // FST with no states
967   VectorFst<Arc> zero_fst_;
968   // FST with one state that accepts epsilon.
969   VectorFst<Arc> one_fst_;
970   // FST with one state that accepts all strings.
971   VectorFst<Arc> univ_fst_;
972   // Generates weights used in testing.
973   WeightGenerator *weight_generator_;
974   // Maximum random path length.
975   static const int kRandomPathLength;
976   // Number of random paths to explore.
977   static const int kNumRandomPaths;
978   // Maximum number of nshortest paths.
979   static const int kNumRandomShortestPaths;
980   // Maximum number of nshortest states.
981   static const int kNumShortestStates;
982   // Delta for equivalence tests.
983   static const float kTestDelta;
984
985   WeightedTester(const WeightedTester &) = delete;
986   WeightedTester &operator=(const WeightedTester &) = delete;
987 };
988
989 template <class A, class WG>
990 const int WeightedTester<A, WG>::kRandomPathLength = 25;
991
992 template <class A, class WG>
993 const int WeightedTester<A, WG>::kNumRandomPaths = 100;
994
995 template <class A, class WG>
996 const int WeightedTester<A, WG>::kNumRandomShortestPaths = 100;
997
998 template <class A, class WG>
999 const int WeightedTester<A, WG>::kNumShortestStates = 10000;
1000
1001 template <class A, class WG>
1002 const float WeightedTester<A, WG>::kTestDelta = .05;
1003
1004 // This class tests a variety of identities and properties that must
1005 // hold for various algorithms on unweighted FSAs and that are not tested
1006 // by WeightedTester. Only the specialization does anything interesting.
1007 template <class Arc>
1008 class UnweightedTester {
1009  public:
1010   UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
1011                    const Fst<Arc> &univ_fsa) {}
1012
1013   void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {}
1014 };
1015
1016 // Specialization for StdArc. This should work for any commutative,
1017 // idempotent semiring when restricted to the unweighted case
1018 // (being isomorphic to the boolean semiring).
1019 template <>
1020 class UnweightedTester<StdArc> {
1021  public:
1022   typedef StdArc Arc;
1023   typedef Arc::Label Label;
1024   typedef Arc::StateId StateId;
1025   typedef Arc::Weight Weight;
1026
1027   UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
1028                    const Fst<Arc> &univ_fsa)
1029       : zero_fsa_(zero_fsa), one_fsa_(one_fsa), univ_fsa_(univ_fsa) {}
1030
1031   void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {
1032     TestRational(A1, A2, A3);
1033     TestIntersect(A1, A2, A3);
1034     TestOptimize(A1);
1035   }
1036
1037  private:
1038   // Tests rational operations with identities
1039   void TestRational(const Fst<Arc> &A1, const Fst<Arc> &A2,
1040                     const Fst<Arc> &A3) {
1041     {
1042       VLOG(1) << "Check the union contains its arguments (destructive).";
1043       VectorFst<Arc> U(A1);
1044       Union(&U, A2);
1045
1046       CHECK(Subset(A1, U));
1047       CHECK(Subset(A2, U));
1048     }
1049
1050     {
1051       VLOG(1) << "Check the union contains its arguments (delayed).";
1052       UnionFst<Arc> U(A1, A2);
1053
1054       CHECK(Subset(A1, U));
1055       CHECK(Subset(A2, U));
1056     }
1057
1058     {
1059       VLOG(1) << "Check if A^n c A* (destructive).";
1060       VectorFst<Arc> C(one_fsa_);
1061       int n = rand() % 5;
1062       for (int i = 0; i < n; ++i) Concat(&C, A1);
1063
1064       VectorFst<Arc> S(A1);
1065       Closure(&S, CLOSURE_STAR);
1066       CHECK(Subset(C, S));
1067     }
1068
1069     {
1070       VLOG(1) << "Check if A^n c A* (delayed).";
1071       int n = rand() % 5;
1072       Fst<Arc> *C = new VectorFst<Arc>(one_fsa_);
1073       for (int i = 0; i < n; ++i) {
1074         ConcatFst<Arc> *F = new ConcatFst<Arc>(*C, A1);
1075         delete C;
1076         C = F;
1077       }
1078       ClosureFst<Arc> S(A1, CLOSURE_STAR);
1079       CHECK(Subset(*C, S));
1080       delete C;
1081     }
1082   }
1083
1084   // Tests intersect-based operations.
1085   void TestIntersect(const Fst<Arc> &A1, const Fst<Arc> &A2,
1086                      const Fst<Arc> &A3) {
1087     VectorFst<Arc> S1(A1);
1088     VectorFst<Arc> S2(A2);
1089     VectorFst<Arc> S3(A3);
1090
1091     ILabelCompare<Arc> comp;
1092
1093     ArcSort(&S1, comp);
1094     ArcSort(&S2, comp);
1095     ArcSort(&S3, comp);
1096
1097     {
1098       VLOG(1) << "Check the intersection is contained in its arguments.";
1099       IntersectFst<Arc> I1(S1, S2);
1100       CHECK(Subset(I1, S1));
1101       CHECK(Subset(I1, S2));
1102     }
1103
1104     {
1105       VLOG(1) << "Check union distributes over intersection.";
1106       IntersectFst<Arc> I1(S1, S2);
1107       UnionFst<Arc> U1(I1, S3);
1108
1109       UnionFst<Arc> U2(S1, S3);
1110       UnionFst<Arc> U3(S2, S3);
1111       ArcSortFst<Arc, ILabelCompare<Arc>> S4(U3, comp);
1112       IntersectFst<Arc> I2(U2, S4);
1113
1114       CHECK(Equiv(U1, I2));
1115     }
1116
1117     VectorFst<Arc> C1;
1118     VectorFst<Arc> C2;
1119     Complement(S1, &C1);
1120     Complement(S2, &C2);
1121     ArcSort(&C1, comp);
1122     ArcSort(&C2, comp);
1123
1124     {
1125       VLOG(1) << "Check S U S' = Sigma*";
1126       UnionFst<Arc> U(S1, C1);
1127       CHECK(Equiv(U, univ_fsa_));
1128     }
1129
1130     {
1131       VLOG(1) << "Check S n S' = {}";
1132       IntersectFst<Arc> I(S1, C1);
1133       CHECK(Equiv(I, zero_fsa_));
1134     }
1135
1136     {
1137       VLOG(1) << "Check (S1' U S2') == (S1 n S2)'";
1138       UnionFst<Arc> U(C1, C2);
1139
1140       IntersectFst<Arc> I(S1, S2);
1141       VectorFst<Arc> C3;
1142       Complement(I, &C3);
1143       CHECK(Equiv(U, C3));
1144     }
1145
1146     {
1147       VLOG(1) << "Check (S1' n S2') == (S1 U S2)'";
1148       IntersectFst<Arc> I(C1, C2);
1149
1150       UnionFst<Arc> U(S1, S2);
1151       VectorFst<Arc> C3;
1152       Complement(U, &C3);
1153       CHECK(Equiv(I, C3));
1154     }
1155   }
1156
1157   // Tests optimization operations
1158   void TestOptimize(const Fst<Arc> &A) {
1159     {
1160       VLOG(1) << "Check determinized FSA is equivalent to its input.";
1161       DeterminizeFst<Arc> D(A);
1162       CHECK(Equiv(A, D));
1163     }
1164
1165     {
1166       VLOG(1) << "Check disambiguated FSA is equivalent to its input.";
1167       VectorFst<Arc> R(A), D;
1168       RmEpsilon(&R);
1169
1170       Disambiguate(R, &D);
1171       CHECK(Equiv(R, D));
1172     }
1173
1174     {
1175       VLOG(1) << "Check minimized FSA is equivalent to its input.";
1176       int n;
1177       {
1178         RmEpsilonFst<Arc> R(A);
1179         DeterminizeFst<Arc> D(R);
1180         VectorFst<Arc> M(D);
1181         Minimize(&M);
1182         CHECK(Equiv(A, M));
1183         n = M.NumStates();
1184       }
1185
1186       if (n) {  // Skip test if A is the empty machine
1187         VLOG(1) << "Check that Hopcroft's and Revuz's algorithms lead to the"
1188                 << " same number of states as Brozozowski's algorithm";
1189         VectorFst<Arc> R;
1190         Reverse(A, &R);
1191         RmEpsilon(&R);
1192         DeterminizeFst<Arc> DR(R);
1193         VectorFst<Arc> RD;
1194         Reverse(DR, &RD);
1195         DeterminizeFst<Arc> DRD(RD);
1196         VectorFst<Arc> M(DRD);
1197         CHECK_EQ(n + 1, M.NumStates());  // Accounts for the epsilon transition
1198                                          // to the initial state
1199       }
1200     }
1201   }
1202
1203   // Tests if two FSAS are equivalent.
1204   bool Equiv(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) {
1205     VLOG(1) << "Check FSAs for sanity (including property bits).";
1206     CHECK(Verify(fsa1));
1207     CHECK(Verify(fsa2));
1208
1209     VectorFst<Arc> vfsa1(fsa1);
1210     VectorFst<Arc> vfsa2(fsa2);
1211     RmEpsilon(&vfsa1);
1212     RmEpsilon(&vfsa2);
1213     DeterminizeFst<Arc> dfa1(vfsa1);
1214     DeterminizeFst<Arc> dfa2(vfsa2);
1215
1216     // Test equivalence using union-find algorithm
1217     bool equiv1 = Equivalent(dfa1, dfa2);
1218
1219     // Test equivalence by checking if (S1 - S2) U (S2 - S1) is empty
1220     ILabelCompare<Arc> comp;
1221     VectorFst<Arc> sdfa1(dfa1);
1222     ArcSort(&sdfa1, comp);
1223     VectorFst<Arc> sdfa2(dfa2);
1224     ArcSort(&sdfa2, comp);
1225
1226     DifferenceFst<Arc> dfsa1(sdfa1, sdfa2);
1227     DifferenceFst<Arc> dfsa2(sdfa2, sdfa1);
1228
1229     VectorFst<Arc> ufsa(dfsa1);
1230     Union(&ufsa, dfsa2);
1231     Connect(&ufsa);
1232     bool equiv2 = ufsa.NumStates() == 0;
1233
1234     // Check two equivalence tests match
1235     CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2));
1236
1237     return equiv1;
1238   }
1239
1240   // Tests if FSA1 is a subset of FSA2 (disregarding weights).
1241   bool Subset(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) {
1242     VLOG(1) << "Check FSAs (incl. property bits) for sanity";
1243     CHECK(Verify(fsa1));
1244     CHECK(Verify(fsa2));
1245
1246     VectorFst<StdArc> vfsa1;
1247     VectorFst<StdArc> vfsa2;
1248     RmEpsilon(&vfsa1);
1249     RmEpsilon(&vfsa2);
1250     ILabelCompare<StdArc> comp;
1251     ArcSort(&vfsa1, comp);
1252     ArcSort(&vfsa2, comp);
1253     IntersectFst<StdArc> ifsa(vfsa1, vfsa2);
1254     DeterminizeFst<StdArc> dfa1(vfsa1);
1255     DeterminizeFst<StdArc> dfa2(ifsa);
1256     return Equivalent(dfa1, dfa2);
1257   }
1258
1259   // Returns complement Fsa
1260   void Complement(const Fst<Arc> &ifsa, MutableFst<Arc> *ofsa) {
1261     RmEpsilonFst<Arc> rfsa(ifsa);
1262     DeterminizeFst<Arc> dfa(rfsa);
1263     DifferenceFst<Arc> cfsa(univ_fsa_, dfa);
1264     *ofsa = cfsa;
1265   }
1266
1267   // FSA with no states
1268   VectorFst<Arc> zero_fsa_;
1269
1270   // FSA with one state that accepts epsilon.
1271   VectorFst<Arc> one_fsa_;
1272
1273   // FSA with one state that accepts all strings.
1274   VectorFst<Arc> univ_fsa_;
1275 };
1276
1277 // This class tests a variety of identities and properties that must
1278 // hold for various FST algorithms. It randomly generates FSTs, using
1279 // function object 'weight_generator' to select weights. 'WeightTester'
1280 // and 'UnweightedTester' are then called.
1281 template <class Arc, class WeightGenerator>
1282 class AlgoTester {
1283  public:
1284   typedef typename Arc::Label Label;
1285   typedef typename Arc::StateId StateId;
1286   typedef typename Arc::Weight Weight;
1287
1288   AlgoTester(WeightGenerator generator, int seed)
1289       : weight_generator_(generator) {
1290     one_fst_.AddState();
1291     one_fst_.SetStart(0);
1292     one_fst_.SetFinal(0, Weight::One());
1293
1294     univ_fst_.AddState();
1295     univ_fst_.SetStart(0);
1296     univ_fst_.SetFinal(0, Weight::One());
1297     for (int i = 0; i < kNumRandomLabels; ++i)
1298       univ_fst_.AddArc(0, Arc(i, i, Weight::One(), 0));
1299
1300     weighted_tester_ = new WeightedTester<Arc, WeightGenerator>(
1301         seed, zero_fst_, one_fst_, univ_fst_, &weight_generator_);
1302
1303     unweighted_tester_ =
1304         new UnweightedTester<Arc>(zero_fst_, one_fst_, univ_fst_);
1305   }
1306
1307   ~AlgoTester() {
1308     delete weighted_tester_;
1309     delete unweighted_tester_;
1310   }
1311
1312   void MakeRandFst(MutableFst<Arc> *fst) {
1313     RandFst<Arc, WeightGenerator>(kNumRandomStates, kNumRandomArcs,
1314                                   kNumRandomLabels, kAcyclicProb,
1315                                   &weight_generator_, fst);
1316   }
1317
1318   void Test() {
1319     VLOG(1) << "weight type = " << Weight::Type();
1320
1321     for (int i = 0; i < FLAGS_repeat; ++i) {
1322       // Random transducers
1323       VectorFst<Arc> T1;
1324       VectorFst<Arc> T2;
1325       VectorFst<Arc> T3;
1326       MakeRandFst(&T1);
1327       MakeRandFst(&T2);
1328       MakeRandFst(&T3);
1329       weighted_tester_->Test(T1, T2, T3);
1330
1331       VectorFst<Arc> A1(T1);
1332       VectorFst<Arc> A2(T2);
1333       VectorFst<Arc> A3(T3);
1334       Project(&A1, PROJECT_OUTPUT);
1335       Project(&A2, PROJECT_INPUT);
1336       Project(&A3, PROJECT_INPUT);
1337       ArcMap(&A1, rm_weight_mapper_);
1338       ArcMap(&A2, rm_weight_mapper_);
1339       ArcMap(&A3, rm_weight_mapper_);
1340       unweighted_tester_->Test(A1, A2, A3);
1341     }
1342   }
1343
1344  private:
1345   // Generates weights used in testing.
1346   WeightGenerator weight_generator_;
1347
1348   // FST with no states
1349   VectorFst<Arc> zero_fst_;
1350
1351   // FST with one state that accepts epsilon.
1352   VectorFst<Arc> one_fst_;
1353
1354   // FST with one state that accepts all strings.
1355   VectorFst<Arc> univ_fst_;
1356
1357   // Tests weighted FSTs
1358   WeightedTester<Arc, WeightGenerator> *weighted_tester_;
1359
1360   // Tests unweighted FSTs
1361   UnweightedTester<Arc> *unweighted_tester_;
1362
1363   // Mapper to remove weights from an Fst
1364   RmWeightMapper<Arc> rm_weight_mapper_;
1365
1366   // Maximum number of states in random test Fst.
1367   static const int kNumRandomStates;
1368
1369   // Maximum number of arcs in random test Fst.
1370   static const int kNumRandomArcs;
1371
1372   // Number of alternative random labels.
1373   static const int kNumRandomLabels;
1374
1375   // Probability to force an acyclic Fst
1376   static const float kAcyclicProb;
1377
1378   // Maximum random path length.
1379   static const int kRandomPathLength;
1380
1381   // Number of random paths to explore.
1382   static const int kNumRandomPaths;
1383
1384   AlgoTester(const AlgoTester &) = delete;
1385   AlgoTester &operator=(const AlgoTester &) = delete;
1386 };
1387
1388 template <class A, class G>
1389 const int AlgoTester<A, G>::kNumRandomStates = 10;
1390
1391 template <class A, class G>
1392 const int AlgoTester<A, G>::kNumRandomArcs = 25;
1393
1394 template <class A, class G>
1395 const int AlgoTester<A, G>::kNumRandomLabels = 5;
1396
1397 template <class A, class G>
1398 const float AlgoTester<A, G>::kAcyclicProb = .25;
1399
1400 template <class A, class G>
1401 const int AlgoTester<A, G>::kRandomPathLength = 25;
1402
1403 template <class A, class G>
1404 const int AlgoTester<A, G>::kNumRandomPaths = 100;
1405
1406 }  // namespace fst
1407
1408 #endif  // FST_TEST_ALGO_TEST_H_