1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Regression test for various FST algorithms.
6 #ifndef FST_TEST_ALGO_TEST_H_
7 #define FST_TEST_ALGO_TEST_H_
11 #include <fst/fstlib.h>
12 #include "./rand-fst.h"
14 DECLARE_int32(repeat); // defined in ./algo_test.cc
18 // Mapper to change input and output label of every transition into
25 A operator()(const A &arc) const {
26 return A(0, 0, arc.weight, arc.nextstate);
29 uint64 Properties(uint64 props) const {
30 props &= ~kNotAcceptor;
32 props &= ~kNoIEpsilons & ~kNoOEpsilons & ~kNoEpsilons;
33 props |= kIEpsilons | kOEpsilons | kEpsilons;
34 props &= ~kNotILabelSorted & ~kNotOLabelSorted;
35 props |= kILabelSorted | kOLabelSorted;
// Final-weight action for ArcMap: MAP_NO_SUPERFINAL. Per OpenFst's
// MapFinalAction docs this maps final weights to final weights without
// adding a superfinal state (assumed; see fst/arc-map.h).
39 MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
// Input symbol-table action for ArcMap: MAP_COPY_SYMBOLS, i.e. the input
// symbol table is presumably carried over unchanged (see fst/arc-map.h).
41 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
// Output symbol-table action for ArcMap: MAP_COPY_SYMBOLS, i.e. the output
// symbol table is presumably carried over unchanged (see fst/arc-map.h).
43 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
46 // Generic version: plain composition, no lookahead.
48 void LookAheadCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
49 MutableFst<Arc> *ofst) {
50 Compose(ifst1, ifst2, ofst);
53 // StdArc overload: uses label-lookahead composition when ifst1 has no output-epsilon cycles; otherwise plain composition.
54 void LookAheadCompose(const Fst<StdArc> &ifst1, const Fst<StdArc> &ifst2,
55 MutableFst<StdArc> *ofst) {
56 std::vector<StdArc::StateId> order;
58 TopOrderVisitor<StdArc> visitor(&order, &acyclic);
59 DfsVisit(ifst1, &visitor, OutputEpsilonArcFilter<StdArc>());
60 if (acyclic) { // no ifst1 output epsilon cycles?
61 StdOLabelLookAheadFst lfst1(ifst1);
62 StdVectorFst lfst2(ifst2);
63 LabelLookAheadRelabeler<StdArc>::Relabel(&lfst2, lfst1, true);
64 Compose(lfst1, lfst2, ofst);
66 Compose(ifst1, ifst2, ofst);
70 // This class tests a variety of identities and properties that must
71 // hold for various algorithms on weighted FSTs.
72 template <class Arc, class WeightGenerator>
73 class WeightedTester {
75 typedef typename Arc::Label Label;
76 typedef typename Arc::StateId StateId;
77 typedef typename Arc::Weight Weight;
79 WeightedTester(time_t seed, const Fst<Arc> &zero_fst, const Fst<Arc> &one_fst,
80 const Fst<Arc> &univ_fst, WeightGenerator *weight_generator)
85 weight_generator_(weight_generator) {}
87 void Test(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) {
88 TestRational(T1, T2, T3);
90 TestCompose(T1, T2, T3);
97 // Tests rational operations (union, concatenation, closure) against algebraic identities.
98 void TestRational(const Fst<Arc> &T1, const Fst<Arc> &T2,
101 VLOG(1) << "Check destructive and delayed union are equivalent.";
102 VectorFst<Arc> U1(T1);
104 UnionFst<Arc> U2(T1, T2);
105 CHECK(Equiv(U1, U2));
109 VLOG(1) << "Check destructive and delayed concatenation are equivalent.";
110 VectorFst<Arc> C1(T1);
112 ConcatFst<Arc> C2(T1, T2);
113 CHECK(Equiv(C1, C2));
114 VectorFst<Arc> C3(T2);
116 CHECK(Equiv(C3, C2));
120 VLOG(1) << "Check destructive and delayed closure* are equivalent.";
121 VectorFst<Arc> C1(T1);
122 Closure(&C1, CLOSURE_STAR);
123 ClosureFst<Arc> C2(T1, CLOSURE_STAR);
124 CHECK(Equiv(C1, C2));
128 VLOG(1) << "Check destructive and delayed closure+ are equivalent.";
129 VectorFst<Arc> C1(T1);
130 Closure(&C1, CLOSURE_PLUS);
131 ClosureFst<Arc> C2(T1, CLOSURE_PLUS);
132 CHECK(Equiv(C1, C2));
136 VLOG(1) << "Check union is associative (destructive).";
137 VectorFst<Arc> U1(T1);
141 VectorFst<Arc> U3(T2);
143 VectorFst<Arc> U4(T1);
146 CHECK(Equiv(U1, U4));
150 VLOG(1) << "Check union is associative (delayed).";
151 UnionFst<Arc> U1(T1, T2);
152 UnionFst<Arc> U2(U1, T3);
154 UnionFst<Arc> U3(T2, T3);
155 UnionFst<Arc> U4(T1, U3);
157 CHECK(Equiv(U2, U4));
161 VLOG(1) << "Check union is associative (destructive delayed).";
162 UnionFst<Arc> U1(T1, T2);
165 UnionFst<Arc> U3(T2, T3);
166 UnionFst<Arc> U4(T1, U3);
168 CHECK(Equiv(U1, U4));
172 VLOG(1) << "Check concatenation is associative (destructive).";
173 VectorFst<Arc> C1(T1);
177 VectorFst<Arc> C3(T2);
179 VectorFst<Arc> C4(T1);
182 CHECK(Equiv(C1, C4));
186 VLOG(1) << "Check concatenation is associative (delayed).";
187 ConcatFst<Arc> C1(T1, T2);
188 ConcatFst<Arc> C2(C1, T3);
190 ConcatFst<Arc> C3(T2, T3);
191 ConcatFst<Arc> C4(T1, C3);
193 CHECK(Equiv(C2, C4));
197 VLOG(1) << "Check concatenation is associative (destructive delayed).";
198 ConcatFst<Arc> C1(T1, T2);
201 ConcatFst<Arc> C3(T2, T3);
202 ConcatFst<Arc> C4(T1, C3);
204 CHECK(Equiv(C1, C4));
207 if (Weight::Properties() & kLeftSemiring) {
208 VLOG(1) << "Check concatenation left distributes"
209 << " over union (destructive).";
211 VectorFst<Arc> U1(T1);
213 VectorFst<Arc> C1(T3);
216 VectorFst<Arc> C2(T3);
218 VectorFst<Arc> C3(T3);
220 VectorFst<Arc> U2(C2);
223 CHECK(Equiv(C1, U2));
226 if (Weight::Properties() & kRightSemiring) {
227 VLOG(1) << "Check concatenation right distributes"
228 << " over union (destructive).";
229 VectorFst<Arc> U1(T1);
231 VectorFst<Arc> C1(U1);
234 VectorFst<Arc> C2(T1);
236 VectorFst<Arc> C3(T2);
238 VectorFst<Arc> U2(C2);
241 CHECK(Equiv(C1, U2));
244 if (Weight::Properties() & kLeftSemiring) {
245 VLOG(1) << "Check concatenation left distributes over union (delayed).";
246 UnionFst<Arc> U1(T1, T2);
247 ConcatFst<Arc> C1(T3, U1);
249 ConcatFst<Arc> C2(T3, T1);
250 ConcatFst<Arc> C3(T3, T2);
251 UnionFst<Arc> U2(C2, C3);
253 CHECK(Equiv(C1, U2));
256 if (Weight::Properties() & kRightSemiring) {
257 VLOG(1) << "Check concatenation right distributes over union (delayed).";
258 UnionFst<Arc> U1(T1, T2);
259 ConcatFst<Arc> C1(U1, T3);
261 ConcatFst<Arc> C2(T1, T3);
262 ConcatFst<Arc> C3(T2, T3);
263 UnionFst<Arc> U2(C2, C3);
265 CHECK(Equiv(C1, U2));
268 if (Weight::Properties() & kLeftSemiring) {
269 VLOG(1) << "Check T T* == T+ (destructive).";
270 VectorFst<Arc> S(T1);
271 Closure(&S, CLOSURE_STAR);
272 VectorFst<Arc> C(T1);
275 VectorFst<Arc> P(T1);
276 Closure(&P, CLOSURE_PLUS);
281 if (Weight::Properties() & kRightSemiring) {
282 VLOG(1) << "Check T* T == T+ (destructive).";
283 VectorFst<Arc> S(T1);
284 Closure(&S, CLOSURE_STAR);
288 VectorFst<Arc> P(T1);
289 Closure(&P, CLOSURE_PLUS);
294 if (Weight::Properties() & kLeftSemiring) {
295 VLOG(1) << "Check T T* == T+ (delayed).";
296 ClosureFst<Arc> S(T1, CLOSURE_STAR);
297 ConcatFst<Arc> C(T1, S);
299 ClosureFst<Arc> P(T1, CLOSURE_PLUS);
304 if (Weight::Properties() & kRightSemiring) {
305 VLOG(1) << "Check T* T == T+ (delayed).";
306 ClosureFst<Arc> S(T1, CLOSURE_STAR);
307 ConcatFst<Arc> C(S, T1);
309 ClosureFst<Arc> P(T1, CLOSURE_PLUS);
315 // Tests map-based operations.
316 void TestMap(const Fst<Arc> &T) {
318 VLOG(1) << "Check destructive and delayed projection are equivalent.";
319 VectorFst<Arc> P1(T);
320 Project(&P1, PROJECT_INPUT);
321 ProjectFst<Arc> P2(T, PROJECT_INPUT);
322 CHECK(Equiv(P1, P2));
326 VLOG(1) << "Check destructive and delayed inversion are equivalent.";
327 VectorFst<Arc> I1(T);
329 InvertFst<Arc> I2(T);
330 CHECK(Equiv(I1, I2));
334 VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (destructive).";
335 VectorFst<Arc> P1(T);
336 VectorFst<Arc> I1(T);
337 Project(&P1, PROJECT_INPUT);
339 Project(&I1, PROJECT_OUTPUT);
340 CHECK(Equiv(P1, I1));
344 VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (destructive).";
345 VectorFst<Arc> P1(T);
346 VectorFst<Arc> I1(T);
347 Project(&P1, PROJECT_OUTPUT);
349 Project(&I1, PROJECT_INPUT);
350 CHECK(Equiv(P1, I1));
354 VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (delayed).";
355 ProjectFst<Arc> P1(T, PROJECT_INPUT);
356 InvertFst<Arc> I1(T);
357 ProjectFst<Arc> P2(I1, PROJECT_OUTPUT);
358 CHECK(Equiv(P1, P2));
362 VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (delayed).";
363 ProjectFst<Arc> P1(T, PROJECT_OUTPUT);
364 InvertFst<Arc> I1(T);
365 ProjectFst<Arc> P2(I1, PROJECT_INPUT);
366 CHECK(Equiv(P1, P2));
370 VLOG(1) << "Check destructive relabeling";
371 static const int kNumLabels = 10;
372 // set up relabeling pairs
373 std::vector<Label> labelset(kNumLabels);
374 for (size_t i = 0; i < kNumLabels; ++i) labelset[i] = i;
375 for (size_t i = 0; i < kNumLabels; ++i) {
377 swap(labelset[i], labelset[rand() % kNumLabels]);
380 std::vector<std::pair<Label, Label>> ipairs1(kNumLabels);
381 std::vector<std::pair<Label, Label>> opairs1(kNumLabels);
382 for (size_t i = 0; i < kNumLabels; ++i) {
383 ipairs1[i] = std::make_pair(i, labelset[i]);
384 opairs1[i] = std::make_pair(labelset[i], i);
387 Relabel(&R, ipairs1, opairs1);
389 std::vector<std::pair<Label, Label>> ipairs2(kNumLabels);
390 std::vector<std::pair<Label, Label>> opairs2(kNumLabels);
391 for (size_t i = 0; i < kNumLabels; ++i) {
392 ipairs2[i] = std::make_pair(labelset[i], i);
393 opairs2[i] = std::make_pair(i, labelset[i]);
395 Relabel(&R, ipairs2, opairs2);
398 VLOG(1) << "Check on-the-fly relabeling";
399 RelabelFst<Arc> Rdelay(T, ipairs1, opairs1);
401 RelabelFst<Arc> RRdelay(Rdelay, ipairs2, opairs2);
402 CHECK(Equiv(RRdelay, T));
406 VLOG(1) << "Check encoding/decoding (destructive).";
408 uint32 encode_props = 0;
409 if (rand() % 2) encode_props |= kEncodeLabels;
410 if (rand() % 2) encode_props |= kEncodeWeights;
411 EncodeMapper<Arc> encoder(encode_props, ENCODE);
412 Encode(&D, &encoder);
418 VLOG(1) << "Check encoding/decoding (delayed).";
419 uint32 encode_props = 0;
420 if (rand() % 2) encode_props |= kEncodeLabels;
421 if (rand() % 2) encode_props |= kEncodeWeights;
422 EncodeMapper<Arc> encoder(encode_props, ENCODE);
423 EncodeFst<Arc> E(T, &encoder);
424 VectorFst<Arc> Encoded(E);
425 DecodeFst<Arc> D(Encoded, encoder);
430 VLOG(1) << "Check gallic mappers (constructive).";
431 ToGallicMapper<Arc> to_mapper;
432 FromGallicMapper<Arc> from_mapper;
433 VectorFst<GallicArc<Arc>> G;
435 ArcMap(T, &G, to_mapper);
436 ArcMap(G, &F, from_mapper);
441 VLOG(1) << "Check gallic mappers (delayed).";
442 ToGallicMapper<Arc> to_mapper;
443 FromGallicMapper<Arc> from_mapper;
444 ArcMapFst<Arc, GallicArc<Arc>, ToGallicMapper<Arc>> G(T, to_mapper);
445 ArcMapFst<GallicArc<Arc>, Arc, FromGallicMapper<Arc>> F(G, from_mapper);
450 // Tests compose-based operations.
451 void TestCompose(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) {
452 if (!(Weight::Properties() & kCommutative)) return;
454 VectorFst<Arc> S1(T1);
455 VectorFst<Arc> S2(T2);
456 VectorFst<Arc> S3(T3);
458 ILabelCompare<Arc> icomp;
459 OLabelCompare<Arc> ocomp;
466 VLOG(1) << "Check composition is associative.";
467 ComposeFst<Arc> C1(S1, S2);
468 ComposeFst<Arc> C2(C1, S3);
469 ComposeFst<Arc> C3(S2, S3);
470 ComposeFst<Arc> C4(S1, C3);
472 CHECK(Equiv(C2, C4));
476 VLOG(1) << "Check composition left distributes over union.";
477 UnionFst<Arc> U1(S2, S3);
478 ComposeFst<Arc> C1(S1, U1);
480 ComposeFst<Arc> C2(S1, S2);
481 ComposeFst<Arc> C3(S1, S3);
482 UnionFst<Arc> U2(C2, C3);
484 CHECK(Equiv(C1, U2));
488 VLOG(1) << "Check composition right distributes over union.";
489 UnionFst<Arc> U1(S1, S2);
490 ComposeFst<Arc> C1(U1, S3);
492 ComposeFst<Arc> C2(S1, S3);
493 ComposeFst<Arc> C3(S2, S3);
494 UnionFst<Arc> U2(C2, C3);
496 CHECK(Equiv(C1, U2));
499 VectorFst<Arc> A1(S1);
500 VectorFst<Arc> A2(S2);
501 VectorFst<Arc> A3(S3);
502 Project(&A1, PROJECT_OUTPUT);
503 Project(&A2, PROJECT_INPUT);
504 Project(&A3, PROJECT_INPUT);
507 VLOG(1) << "Check intersection is commutative.";
508 IntersectFst<Arc> I1(A1, A2);
509 IntersectFst<Arc> I2(A2, A1);
510 CHECK(Equiv(I1, I2));
514 VLOG(1) << "Check all epsilon filters leads to equivalent results.";
515 typedef Matcher<Fst<Arc>> M;
516 ComposeFst<Arc> C1(S1, S2);
518 S1, S2, ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M>>());
519 ComposeFst<Arc> C3(S1, S2,
520 ComposeFstOptions<Arc, M, MatchComposeFilter<M>>());
522 CHECK(Equiv(C1, C2));
523 CHECK(Equiv(C1, C3));
525 if ((Weight::Properties() & kIdempotent) ||
526 S1.Properties(kNoOEpsilons, false) ||
527 S2.Properties(kNoIEpsilons, false)) {
529 S1, S2, ComposeFstOptions<Arc, M, TrivialComposeFilter<M>>());
530 CHECK(Equiv(C1, C4));
533 if (S1.Properties(kNoOEpsilons, false) &&
534 S2.Properties(kNoIEpsilons, false)) {
535 ComposeFst<Arc> C5(S1, S2,
536 ComposeFstOptions<Arc, M, NullComposeFilter<M>>());
537 CHECK(Equiv(C1, C5));
542 VLOG(1) << "Check look-ahead filters lead to equivalent results.";
543 VectorFst<Arc> C1, C2;
544 Compose(S1, S2, &C1);
545 LookAheadCompose(S1, S2, &C2);
546 CHECK(Equiv(C1, C2));
550 // Tests sorting operations (arc sort, topological sort) and reversal.
551 void TestSort(const Fst<Arc> &T) {
552 ILabelCompare<Arc> icomp;
553 OLabelCompare<Arc> ocomp;
556 VLOG(1) << "Check arc sorted Fst is equivalent to its input.";
557 VectorFst<Arc> S1(T);
563 VLOG(1) << "Check destructive and delayed arcsort are equivalent.";
564 VectorFst<Arc> S1(T);
566 ArcSortFst<Arc, ILabelCompare<Arc>> S2(T, icomp);
567 CHECK(Equiv(S1, S2));
571 VLOG(1) << "Check ilabel sorting vs. olabel sorting with inversions.";
572 VectorFst<Arc> S1(T);
573 VectorFst<Arc> S2(T);
578 CHECK(Equiv(S1, S2));
582 VLOG(1) << "Check topologically sorted Fst is equivalent to its input.";
583 VectorFst<Arc> S1(T);
589 VLOG(1) << "Check reverse(reverse(T)) = T";
590 for (int i = 0; i < 2; ++i) {
591 VectorFst<ReverseArc<Arc>> R1;
593 bool require_superinitial = i == 1;
594 Reverse(T, &R1, require_superinitial);
595 Reverse(R1, &R2, require_superinitial);
601 // Tests optimization operations
602 void TestOptimize(const Fst<Arc> &T) {
603 uint64 tprops = T.Properties(kFstProperties, true);
604 uint64 wprops = Weight::Properties();
607 Project(&A, PROJECT_INPUT);
609 VLOG(1) << "Check connected FST is equivalent to its input.";
610 VectorFst<Arc> C1(T);
615 if ((wprops & kSemiring) == kSemiring &&
616 (tprops & kAcyclic || wprops & kIdempotent)) {
617 VLOG(1) << "Check epsilon-removed FST is equivalent to its input.";
618 VectorFst<Arc> R1(T);
622 VLOG(1) << "Check destructive and delayed epsilon removal"
623 << "are equivalent.";
624 RmEpsilonFst<Arc> R2(T);
625 CHECK(Equiv(R1, R2));
627 VLOG(1) << "Check an FST with a large proportion"
628 << " of epsilon transitions:";
629 // Maps all transitions of T to epsilon transitions and appends
630 // a non-epsilon transition.
632 ArcMap(T, &U, EpsMapper<Arc>());
634 V.SetStart(V.AddState());
635 Arc arc(1, 1, Weight::One(), V.AddState());
636 V.AddArc(V.Start(), arc);
637 V.SetFinal(arc.nextstate, Weight::One());
639 // Check that epsilon-removal preserves the shortest-distance
640 // from the initial state to the final states.
641 std::vector<Weight> d;
642 ShortestDistance(U, &d, true);
643 Weight w = U.Start() < d.size() ? d[U.Start()] : Weight::Zero();
644 VectorFst<Arc> U1(U);
646 ShortestDistance(U1, &d, true);
647 Weight w1 = U1.Start() < d.size() ? d[U1.Start()] : Weight::Zero();
648 CHECK(ApproxEqual(w, w1, kTestDelta));
649 RmEpsilonFst<Arc> U2(U);
650 ShortestDistance(U2, &d, true);
651 Weight w2 = U2.Start() < d.size() ? d[U2.Start()] : Weight::Zero();
652 CHECK(ApproxEqual(w, w2, kTestDelta));
655 if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) {
656 VLOG(1) << "Check determinized FSA is equivalent to its input.";
657 DeterminizeFst<Arc> D(A);
661 VLOG(1) << "Check determinized FST is equivalent to its input.";
662 DeterminizeFstOptions<Arc> opts;
663 opts.type = DETERMINIZE_NONFUNCTIONAL;
664 DeterminizeFst<Arc> DT(T, opts);
668 if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
669 VLOG(1) << "Check pruning in determinization";
671 Weight threshold = (*weight_generator_)();
672 DeterminizeOptions<Arc> opts;
673 opts.weight_threshold = threshold;
674 Determinize(A, &P, opts);
675 CHECK(P.Properties(kIDeterministic, true));
676 CHECK(PruneEquiv(A, P, threshold));
679 if ((wprops & kPath) == kPath) {
680 VLOG(1) << "Check min-determinization";
682 // Ensures no input epsilons
684 std::vector<std::pair<Label, Label>> ipairs, opairs;
685 ipairs.push_back(std::pair<Label, Label>(0, 1));
686 Relabel(&R, ipairs, opairs);
689 DeterminizeOptions<Arc> opts;
690 opts.type = DETERMINIZE_DISAMBIGUATE;
691 Determinize(R, &M, opts);
692 CHECK(M.Properties(kIDeterministic, true));
693 CHECK(MinRelated(M, R));
698 VLOG(1) << "Check size(min(det(A))) <= size(det(A))"
699 << " and min(det(A)) equiv det(A)";
704 CHECK(M.NumStates() <= n);
708 if (n && (wprops & kIdempotent) == kIdempotent &&
709 A.Properties(kNoEpsilons, true)) {
710 VLOG(1) << "Check that Revuz's algorithm leads to the"
711 << " same number of states as Brozozowski's algorithm";
713 // Skip test if A is the empty machine or contains epsilons or
714 // if the semiring is not idempotent (to avoid floating point
719 DeterminizeFst<Arc> DR(R);
722 DeterminizeFst<Arc> DRD(RD);
723 VectorFst<Arc> M(DRD);
724 CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition
725 // to the initial state
729 if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) {
730 VLOG(1) << "Check disambiguated FSA is equivalent to its input.";
731 VectorFst<Arc> R(A), D;
735 VLOG(1) << "Check disambiguated FSA is unambiguous";
736 CHECK(Unambiguous(D));
738 /* TODO(riley): find out why this fails
739 if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
740 VLOG(1) << "Check pruning in disambiguation";
742 Weight threshold = (*weight_generator_)();
743 DisambiguateOptions<Arc> opts;
744 opts.weight_threshold = threshold;
745 Disambiguate(R, &P, opts);
746 CHECK(Unambiguous(P));
747 CHECK(PruneEquiv(A, P, threshold));
752 if (Arc::Type() == LogArc::Type() || Arc::Type() == StdArc::Type()) {
753 VLOG(1) << "Check reweight(T) equiv T";
754 std::vector<Weight> potential;
755 VectorFst<Arc> RI(T);
756 VectorFst<Arc> RF(T);
757 while (potential.size() < RI.NumStates())
758 potential.push_back((*weight_generator_)());
760 Reweight(&RI, potential, REWEIGHT_TO_INITIAL);
763 Reweight(&RF, potential, REWEIGHT_TO_FINAL);
767 if ((wprops & kIdempotent) || (tprops & kAcyclic)) {
768 VLOG(1) << "Check pushed FST is equivalent to input FST.";
769 // Pushing towards the final state.
770 if (wprops & kRightSemiring) {
772 Push<Arc, REWEIGHT_TO_FINAL>(T, &P1, kPushLabels);
776 Push<Arc, REWEIGHT_TO_FINAL>(T, &P2, kPushWeights);
780 Push<Arc, REWEIGHT_TO_FINAL>(T, &P3, kPushLabels | kPushWeights);
784 // Pushing towards the initial state.
785 if (wprops & kLeftSemiring) {
787 Push<Arc, REWEIGHT_TO_INITIAL>(T, &P1, kPushLabels);
791 Push<Arc, REWEIGHT_TO_INITIAL>(T, &P2, kPushWeights);
794 Push<Arc, REWEIGHT_TO_INITIAL>(T, &P3, kPushLabels | kPushWeights);
799 if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
800 VLOG(1) << "Check pruning algorithm";
802 VLOG(1) << "Check equiv. of constructive and destructive algorithms";
803 Weight thresold = (*weight_generator_)();
804 VectorFst<Arc> P1(T);
805 Prune(&P1, thresold);
807 Prune(T, &P2, thresold);
808 CHECK(Equiv(P1, P2));
812 VLOG(1) << "Check prune(reverse) equiv reverse(prune)";
813 Weight thresold = (*weight_generator_)();
814 VectorFst<ReverseArc<Arc>> R;
815 VectorFst<Arc> P1(T);
817 Prune(&P1, thresold);
819 Prune(&R, thresold.Reverse());
821 CHECK(Equiv(P1, P2));
824 VLOG(1) << "Check: ShortestDistance(A - prune(A))"
825 << " > ShortestDistance(A) times Threshold";
826 Weight threshold = (*weight_generator_)();
828 Prune(A, &P, threshold);
829 CHECK(PruneEquiv(A, P, threshold));
832 if (tprops & kAcyclic) {
833 VLOG(1) << "Check synchronize(T) equiv T";
834 SynchronizeFst<Arc> S(T);
839 // Tests search operations
840 void TestSearch(const Fst<Arc> &T) {
841 uint64 wprops = Weight::Properties();
844 Project(&A, PROJECT_INPUT);
846 if ((wprops & (kPath | kRightSemiring)) == (kPath | kRightSemiring)) {
847 VLOG(1) << "Check 1-best weight.";
849 ShortestPath(T, &path);
850 Weight tsum = ShortestDistance(T);
851 Weight psum = ShortestDistance(path);
852 CHECK(ApproxEqual(tsum, psum, kTestDelta));
855 if ((wprops & (kPath | kSemiring)) == (kPath | kSemiring)) {
856 VLOG(1) << "Check n-best weights";
859 int nshortest = rand() % kNumRandomShortestPaths + 2;
860 VectorFst<Arc> paths;
861 ShortestPath(R, &paths, nshortest, true, false, Weight::Zero(),
863 std::vector<Weight> distance;
864 ShortestDistance(paths, &distance, true);
865 StateId pstart = paths.Start();
866 if (pstart != kNoStateId) {
867 ArcIterator<Fst<Arc>> piter(paths, pstart);
868 for (; !piter.Done(); piter.Next()) {
869 StateId s = piter.Value().nextstate;
870 Weight nsum = s < distance.size()
871 ? Times(piter.Value().weight, distance[s])
874 ShortestPath(R, &path);
875 Weight dsum = ShortestDistance(path);
876 CHECK(ApproxEqual(nsum, dsum, kTestDelta));
877 ArcMap(&path, RmWeightMapper<Arc>());
879 Difference(R, path, &S);
886 // Tests whether two FSTs are equivalent by checking that random strings
887 // drawn from one FST are transduced the same way by both FSTs.
889 bool Equiv(const Fst<A> &fst1, const Fst<A> &fst2) {
890 VLOG(1) << "Check FSTs for sanity (including property bits).";
894 // Ensures seed used once per instantiation.
895 static UniformArcSelector<A> uniform_selector(seed_);
896 RandGenOptions<UniformArcSelector<A>> opts(uniform_selector,
898 return RandEquivalent(fst1, fst2, kNumRandomPaths, kTestDelta, opts);
901 // Tests whether FSA 'fst' is unambiguous (no string has two accepting paths).
902 bool Unambiguous(const Fst<Arc> &fst) {
903 VectorFst<StdArc> sfst, dfst;
904 VectorFst<LogArc> lfst1, lfst2;
905 Map(fst, &sfst, RmWeightMapper<Arc, StdArc>());
906 Determinize(sfst, &dfst);
907 Map(fst, &lfst1, RmWeightMapper<Arc, LogArc>());
908 Map(dfst, &lfst2, RmWeightMapper<StdArc, LogArc>());
909 return Equiv(lfst1, lfst2);
912 // Ensures input-epsilon free transducers fst1 and fst2 have the
913 // same domain and that for each string pair '(is, os)' in fst1,
914 // '(is, os)' is the minimum weight match to 'is' in fst2.
916 bool MinRelated(const Fst<A> &fst1, const Fst<A> &fst2) {
918 VectorFst<Arc> P1(fst1), P2(fst2);
919 Project(&P1, PROJECT_INPUT);
920 Project(&P2, PROJECT_INPUT);
921 if (!Equiv(P1, P2)) {
922 LOG(ERROR) << "Inputs not equivalent";
926 // Ensures seed used once per instantiation.
927 static UniformArcSelector<A> uniform_selector(seed_);
928 RandGenOptions<UniformArcSelector<A>> opts(uniform_selector,
931 VectorFst<Arc> path, paths1, paths2;
932 for (ssize_t n = 0; n < kNumRandomPaths; ++n) {
933 RandGen(fst1, &path, opts);
935 Map(&path, RmWeightMapper<Arc>());
936 Compose(path, fst2, &paths1);
937 Weight sum1 = ShortestDistance(paths1);
938 Compose(paths1, path, &paths2);
939 Weight sum2 = ShortestDistance(paths2);
940 if (!ApproxEqual(Plus(sum1, sum2), sum2, kTestDelta)) {
941 LOG(ERROR) << "Sums not equivalent: " << sum1 << " " << sum2;
948 // Tests ShortestDistance(A - P) >=
949 // ShortestDistance(A) times Threshold.
951 bool PruneEquiv(const Fst<A> &fst, const Fst<A> &pfst, Weight threshold) {
952 VLOG(1) << "Check FSTs for sanity (including property bits).";
956 DifferenceFst<Arc> D(fst, DeterminizeFst<Arc>(RmEpsilonFst<Arc>(
957 ArcMapFst<Arc, Arc, RmWeightMapper<Arc>>(
958 pfst, RmWeightMapper<Arc>()))));
959 Weight sum1 = Times(ShortestDistance(fst), threshold);
960 Weight sum2 = ShortestDistance(D);
961 return ApproxEqual(Plus(sum1, sum2), sum1, kTestDelta);
966 // FST with no states
967 VectorFst<Arc> zero_fst_;
968 // FST with one state that accepts epsilon.
969 VectorFst<Arc> one_fst_;
970 // FST with one state that accepts all strings.
971 VectorFst<Arc> univ_fst_;
972 // Generates weights used in testing.
973 WeightGenerator *weight_generator_;
974 // Maximum random path length.
975 static const int kRandomPathLength;
976 // Number of random paths to explore.
977 static const int kNumRandomPaths;
978 // Maximum number of nshortest paths.
979 static const int kNumRandomShortestPaths;
980 // Maximum number of nshortest states.
981 static const int kNumShortestStates;
982 // Delta for equivalence tests.
983 static const float kTestDelta;
985 WeightedTester(const WeightedTester &) = delete;
986 WeightedTester &operator=(const WeightedTester &) = delete;
989 template <class A, class WG>
990 const int WeightedTester<A, WG>::kRandomPathLength = 25;
992 template <class A, class WG>
993 const int WeightedTester<A, WG>::kNumRandomPaths = 100;
995 template <class A, class WG>
996 const int WeightedTester<A, WG>::kNumRandomShortestPaths = 100;
998 template <class A, class WG>
999 const int WeightedTester<A, WG>::kNumShortestStates = 10000;
1001 template <class A, class WG>
1002 const float WeightedTester<A, WG>::kTestDelta = .05;
1004 // This class tests a variety of identities and properties that must
1005 // hold for various algorithms on unweighted FSAs and that are not tested
1006 // by WeightedTester. The generic version is a no-op; only the StdArc specialization tests anything.
1007 template <class Arc>
1008 class UnweightedTester {
1010 UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
1011 const Fst<Arc> &univ_fsa) {}
1013 void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {}
1016 // Specialization for StdArc. This should work for any commutative,
1017 // idempotent semiring when restricted to the unweighted case
1018 // (being isomorphic to the boolean semiring).
1020 class UnweightedTester<StdArc> {
1023 typedef Arc::Label Label;
1024 typedef Arc::StateId StateId;
1025 typedef Arc::Weight Weight;
1027 UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
1028 const Fst<Arc> &univ_fsa)
1029 : zero_fsa_(zero_fsa), one_fsa_(one_fsa), univ_fsa_(univ_fsa) {}
1031 void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {
1032 TestRational(A1, A2, A3);
1033 TestIntersect(A1, A2, A3);
1038 // Tests rational operations (union, concatenation, closure) against set identities.
1039 void TestRational(const Fst<Arc> &A1, const Fst<Arc> &A2,
1040 const Fst<Arc> &A3) {
1042 VLOG(1) << "Check the union contains its arguments (destructive).";
1043 VectorFst<Arc> U(A1);
1046 CHECK(Subset(A1, U));
1047 CHECK(Subset(A2, U));
1051 VLOG(1) << "Check the union contains its arguments (delayed).";
1052 UnionFst<Arc> U(A1, A2);
1054 CHECK(Subset(A1, U));
1055 CHECK(Subset(A2, U));
1059 VLOG(1) << "Check if A^n c A* (destructive).";
1060 VectorFst<Arc> C(one_fsa_);
1062 for (int i = 0; i < n; ++i) Concat(&C, A1);
1064 VectorFst<Arc> S(A1);
1065 Closure(&S, CLOSURE_STAR);
1066 CHECK(Subset(C, S));
1070 VLOG(1) << "Check if A^n c A* (delayed).";
1072 Fst<Arc> *C = new VectorFst<Arc>(one_fsa_);
1073 for (int i = 0; i < n; ++i) {
1074 ConcatFst<Arc> *F = new ConcatFst<Arc>(*C, A1);
1078 ClosureFst<Arc> S(A1, CLOSURE_STAR);
1079 CHECK(Subset(*C, S));
1084 // Tests intersect-based operations.
1085 void TestIntersect(const Fst<Arc> &A1, const Fst<Arc> &A2,
1086 const Fst<Arc> &A3) {
1087 VectorFst<Arc> S1(A1);
1088 VectorFst<Arc> S2(A2);
1089 VectorFst<Arc> S3(A3);
1091 ILabelCompare<Arc> comp;
1098 VLOG(1) << "Check the intersection is contained in its arguments.";
1099 IntersectFst<Arc> I1(S1, S2);
1100 CHECK(Subset(I1, S1));
1101 CHECK(Subset(I1, S2));
1105 VLOG(1) << "Check union distributes over intersection.";
1106 IntersectFst<Arc> I1(S1, S2);
1107 UnionFst<Arc> U1(I1, S3);
1109 UnionFst<Arc> U2(S1, S3);
1110 UnionFst<Arc> U3(S2, S3);
1111 ArcSortFst<Arc, ILabelCompare<Arc>> S4(U3, comp);
1112 IntersectFst<Arc> I2(U2, S4);
1114 CHECK(Equiv(U1, I2));
1119 Complement(S1, &C1);
1120 Complement(S2, &C2);
1125 VLOG(1) << "Check S U S' = Sigma*";
1126 UnionFst<Arc> U(S1, C1);
1127 CHECK(Equiv(U, univ_fsa_));
1131 VLOG(1) << "Check S n S' = {}";
1132 IntersectFst<Arc> I(S1, C1);
1133 CHECK(Equiv(I, zero_fsa_));
1137 VLOG(1) << "Check (S1' U S2') == (S1 n S2)'";
1138 UnionFst<Arc> U(C1, C2);
1140 IntersectFst<Arc> I(S1, S2);
1143 CHECK(Equiv(U, C3));
1147 VLOG(1) << "Check (S1' n S2') == (S1 U S2)'";
1148 IntersectFst<Arc> I(C1, C2);
1150 UnionFst<Arc> U(S1, S2);
1153 CHECK(Equiv(I, C3));
1157 // Tests optimization operations
1158 void TestOptimize(const Fst<Arc> &A) {
1160 VLOG(1) << "Check determinized FSA is equivalent to its input.";
1161 DeterminizeFst<Arc> D(A);
1166 VLOG(1) << "Check disambiguated FSA is equivalent to its input.";
1167 VectorFst<Arc> R(A), D;
1170 Disambiguate(R, &D);
1175 VLOG(1) << "Check minimized FSA is equivalent to its input.";
1178 RmEpsilonFst<Arc> R(A);
1179 DeterminizeFst<Arc> D(R);
1180 VectorFst<Arc> M(D);
1186 if (n) { // Skip test if A is the empty machine
1187 VLOG(1) << "Check that Hopcroft's and Revuz's algorithms lead to the"
1188 << " same number of states as Brozozowski's algorithm";
1192 DeterminizeFst<Arc> DR(R);
1195 DeterminizeFst<Arc> DRD(RD);
1196 VectorFst<Arc> M(DRD);
1197 CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition
1198 // to the initial state
1203 // Tests whether two FSAs are equivalent, using two independent checks that must agree.
1204 bool Equiv(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) {
1205 VLOG(1) << "Check FSAs for sanity (including property bits).";
1206 CHECK(Verify(fsa1));
1207 CHECK(Verify(fsa2));
1209 VectorFst<Arc> vfsa1(fsa1);
1210 VectorFst<Arc> vfsa2(fsa2);
1213 DeterminizeFst<Arc> dfa1(vfsa1);
1214 DeterminizeFst<Arc> dfa2(vfsa2);
1216 // Test equivalence using union-find algorithm
1217 bool equiv1 = Equivalent(dfa1, dfa2);
1219 // Test equivalence by checking if (S1 - S2) U (S2 - S1) is empty
1220 ILabelCompare<Arc> comp;
1221 VectorFst<Arc> sdfa1(dfa1);
1222 ArcSort(&sdfa1, comp);
1223 VectorFst<Arc> sdfa2(dfa2);
1224 ArcSort(&sdfa2, comp);
1226 DifferenceFst<Arc> dfsa1(sdfa1, sdfa2);
1227 DifferenceFst<Arc> dfsa2(sdfa2, sdfa1);
1229 VectorFst<Arc> ufsa(dfsa1);
1230 Union(&ufsa, dfsa2);
1232 bool equiv2 = ufsa.NumStates() == 0;
1234 // Check two equivalence tests match
1235 CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2));
1240 // Tests if FSA1 is a subset of FSA2 (disregarding weights).
1241 bool Subset(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) {
1242 VLOG(1) << "Check FSAs (incl. property bits) for sanity";
1243 CHECK(Verify(fsa1));
1244 CHECK(Verify(fsa2));
1246 VectorFst<StdArc> vfsa1;
1247 VectorFst<StdArc> vfsa2;
1250 ILabelCompare<StdArc> comp;
1251 ArcSort(&vfsa1, comp);
1252 ArcSort(&vfsa2, comp);
1253 IntersectFst<StdArc> ifsa(vfsa1, vfsa2);
1254 DeterminizeFst<StdArc> dfa1(vfsa1);
1255 DeterminizeFst<StdArc> dfa2(ifsa);
1256 return Equivalent(dfa1, dfa2);
1259 // Computes the complement of FSA 'ifsa' with respect to univ_fsa_ into 'ofsa'.
1260 void Complement(const Fst<Arc> &ifsa, MutableFst<Arc> *ofsa) {
1261 RmEpsilonFst<Arc> rfsa(ifsa);
1262 DeterminizeFst<Arc> dfa(rfsa);
1263 DifferenceFst<Arc> cfsa(univ_fsa_, dfa);
1267 // FSA with no states
1268 VectorFst<Arc> zero_fsa_;
1270 // FSA with one state that accepts epsilon.
1271 VectorFst<Arc> one_fsa_;
1273 // FSA with one state that accepts all strings.
1274 VectorFst<Arc> univ_fsa_;
1277 // This class tests a variety of identities and properties that must
1278 // hold for various FST algorithms. It randomly generates FSTs, using
1279 // function object 'weight_generator' to select weights. 'WeightedTester'
1280 // and 'UnweightedTester' (on unweighted projections) are then run on them.
1281 template <class Arc, class WeightGenerator>
1284 typedef typename Arc::Label Label;
1285 typedef typename Arc::StateId StateId;
1286 typedef typename Arc::Weight Weight;
1288 AlgoTester(WeightGenerator generator, int seed)
1289 : weight_generator_(generator) {
1290 one_fst_.AddState();
1291 one_fst_.SetStart(0);
1292 one_fst_.SetFinal(0, Weight::One());
1294 univ_fst_.AddState();
1295 univ_fst_.SetStart(0);
1296 univ_fst_.SetFinal(0, Weight::One());
1297 for (int i = 0; i < kNumRandomLabels; ++i)
1298 univ_fst_.AddArc(0, Arc(i, i, Weight::One(), 0));
1300 weighted_tester_ = new WeightedTester<Arc, WeightGenerator>(
1301 seed, zero_fst_, one_fst_, univ_fst_, &weight_generator_);
1303 unweighted_tester_ =
1304 new UnweightedTester<Arc>(zero_fst_, one_fst_, univ_fst_);
1308 delete weighted_tester_;
1309 delete unweighted_tester_;
1312 void MakeRandFst(MutableFst<Arc> *fst) {
1313 RandFst<Arc, WeightGenerator>(kNumRandomStates, kNumRandomArcs,
1314 kNumRandomLabels, kAcyclicProb,
1315 &weight_generator_, fst);
1319 VLOG(1) << "weight type = " << Weight::Type();
1321 for (int i = 0; i < FLAGS_repeat; ++i) {
1322 // Random transducers
1329 weighted_tester_->Test(T1, T2, T3);
1331 VectorFst<Arc> A1(T1);
1332 VectorFst<Arc> A2(T2);
1333 VectorFst<Arc> A3(T3);
1334 Project(&A1, PROJECT_OUTPUT);
1335 Project(&A2, PROJECT_INPUT);
1336 Project(&A3, PROJECT_INPUT);
1337 ArcMap(&A1, rm_weight_mapper_);
1338 ArcMap(&A2, rm_weight_mapper_);
1339 ArcMap(&A3, rm_weight_mapper_);
1340 unweighted_tester_->Test(A1, A2, A3);
1345 // Generates weights used in testing.
1346 WeightGenerator weight_generator_;
1348 // FST with no states
1349 VectorFst<Arc> zero_fst_;
1351 // FST with one state that accepts epsilon.
1352 VectorFst<Arc> one_fst_;
1354 // FST with one state that accepts all strings.
1355 VectorFst<Arc> univ_fst_;
1357 // Tests weighted FSTs
1358 WeightedTester<Arc, WeightGenerator> *weighted_tester_;
1360 // Tests unweighted FSTs
1361 UnweightedTester<Arc> *unweighted_tester_;
1363 // Mapper to remove weights from an Fst
1364 RmWeightMapper<Arc> rm_weight_mapper_;
1366 // Maximum number of states in random test Fst.
1367 static const int kNumRandomStates;
1369 // Maximum number of arcs in random test Fst.
1370 static const int kNumRandomArcs;
1372 // Number of alternative random labels.
1373 static const int kNumRandomLabels;
1375 // Probability to force an acyclic Fst
1376 static const float kAcyclicProb;
1378 // Maximum random path length.
1379 static const int kRandomPathLength;
1381 // Number of random paths to explore.
1382 static const int kNumRandomPaths;
1384 AlgoTester(const AlgoTester &) = delete;
1385 AlgoTester &operator=(const AlgoTester &) = delete;
1388 template <class A, class G>
1389 const int AlgoTester<A, G>::kNumRandomStates = 10;
1391 template <class A, class G>
1392 const int AlgoTester<A, G>::kNumRandomArcs = 25;
1394 template <class A, class G>
1395 const int AlgoTester<A, G>::kNumRandomLabels = 5;
1397 template <class A, class G>
1398 const float AlgoTester<A, G>::kAcyclicProb = .25;
1400 template <class A, class G>
1401 const int AlgoTester<A, G>::kRandomPathLength = 25;
1403 template <class A, class G>
1404 const int AlgoTester<A, G>::kNumRandomPaths = 100;
1408 #endif // FST_TEST_ALGO_TEST_H_