1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
4 // Regression test for FST weights.
10 #include <fst/expectation-weight.h>
11 #include <fst/float-weight.h>
12 #include <fst/lexicographic-weight.h>
13 #include <fst/power-weight.h>
14 #include <fst/product-weight.h>
15 #include <fst/signed-log-weight.h>
16 #include <fst/sparse-power-weight.h>
17 #include <fst/string-weight.h>
18 #include <fst/union-weight.h>
19 #include "./weight-tester.h"
// Command-line flags for this regression test.
// --seed: random seed (default -1); main() logs its value at startup.
21 DEFINE_int32(seed, -1, "random seed");
// --repeat: number of test repetitions; main() forwards it to
// TestTemplatedWeights() and to most of the WeightTester::Test() calls.
22 DEFINE_int32(repeat, 10000, "number of test repetitions");
27 using fst::ExpectationWeight;
29 using fst::GallicWeight;
30 using fst::LexicographicWeight;
32 using fst::LogWeightTpl;
33 using fst::MinMaxWeight;
34 using fst::MinMaxWeightTpl;
35 using fst::NaturalLess;
36 using fst::PowerWeight;
37 using fst::ProductWeight;
38 using fst::SignedLogWeight;
39 using fst::SignedLogWeightTpl;
40 using fst::SparsePowerWeight;
41 using fst::StringWeight;
42 using fst::STRING_LEFT;
43 using fst::STRING_RIGHT;
44 using fst::TropicalWeight;
45 using fst::TropicalWeightTpl;
46 using fst::UnionWeight;
47 using fst::WeightGenerate;
48 using fst::WeightTester;
// Runs WeightTester::Test(repeat) on each weight template parameterized by the
// value type T: TropicalWeightTpl<T>, LogWeightTpl<T>, MinMaxWeightTpl<T> and
// SignedLogWeightTpl<T>. Each tester is constructed with the matching
// WeightGenerate<...> object.
// main() calls this with T = float and T = double, both before and after
// setting FLAGS_fst_weight_parentheses to "()".
51 void TestTemplatedWeights(int repeat) {
// Tropical weights over T.
52 using TropicalWeightGenerate = WeightGenerate<TropicalWeightTpl<T>>;
53 TropicalWeightGenerate tropical_generate;
54 WeightTester<TropicalWeightTpl<T>, TropicalWeightGenerate> tropical_tester(
56 tropical_tester.Test(repeat);
// Log weights over T.
58 using LogWeightGenerate = WeightGenerate<LogWeightTpl<T>>;
59 LogWeightGenerate log_generate;
60 WeightTester<LogWeightTpl<T>, LogWeightGenerate> log_tester(log_generate);
61 log_tester.Test(repeat);
// Min-max weights over T.
// NOTE(review): this is the only generator in this function constructed with
// an explicit argument (`true`); its meaning is defined by WeightGenerate in
// another header — confirm and document it here.
63 using MinMaxWeightGenerate = WeightGenerate<MinMaxWeightTpl<T>>;
64 MinMaxWeightGenerate minmax_generate(true);
65 WeightTester<MinMaxWeightTpl<T>, MinMaxWeightGenerate> minmax_tester(
67 minmax_tester.Test(repeat);
// Signed log weights over T.
69 using SignedLogWeightGenerate = WeightGenerate<SignedLogWeightTpl<T>>;
70 SignedLogWeightGenerate signedlog_generate;
71 WeightTester<SignedLogWeightTpl<T>, SignedLogWeightGenerate>
72 signedlog_tester(signedlog_generate);
73 signedlog_tester.Test(repeat);
// Sums n copies of Weight::One() in two ways: by repeated Plus() into `sum`,
// and by Add() on the `adder` accumulator. It then CHECKs that `sum` and
// adder.Sum() agree under ApproxEqual.
// main() runs this for TropicalWeight and LogWeight with n = 1000.
76 template <class Weight>
77 void TestAdder(int n) {
78 Weight sum = Weight::Zero();
80 for (int i = 0; i < n; ++i) {
// Reference result: plain left-to-right semiring addition.
81 sum = Plus(sum, Weight::One());
82 adder.Add(Weight::One());
84 CHECK(ApproxEqual(sum, adder.Sum()));  // Approximate, not exact, equality.
// Signed counterpart of the adder check. It builds a reference total `sum`
// over n steps:
//   - indices i < n/4 or i > 3*n/4 add Weight::One() with Plus();
//   - the remaining indices subtract it with Minus().
// The `adder` accumulator receives the corresponding contributions
// (presumably `minus_one` on the subtracting steps — confirm). At the end it
// CHECKs that `sum` and adder.Sum() agree under ApproxEqual.
// main() runs this for SignedLogWeight with n = 1000.
87 template <class Weight>
88 void TestSignedAdder(int n) {
89 Weight sum = Weight::Zero();
91 const Weight minus_one = Minus(Weight::Zero(), Weight::One());  // Zero() - One().
92 for (int i = 0; i < n; ++i) {
93 if (i < n/4 || i > 3*n/4) {
94 sum = Plus(sum, Weight::One());
95 adder.Add(Weight::One());
97 sum = Minus(sum, Weight::One());
101 CHECK(ApproxEqual(sum, adder.Sum()));
// Test driver. The stages run in this order:
//   1. The templated float/double weight tests.
//   2. Type-name consistency checks.
//   3. The adder checks.
//   4. The string-weight testers.
//   5. A series of composite-weight testers, run once with
//      FLAGS_fst_weight_parentheses = "()" and again (on a subset) with it
//      set to "".
// Failures surface through CHECK (and presumably inside WeightTester::Test).
// If control reaches the end, "PASS" is printed to stdout.
106 int main(int argc, char **argv) {
107 std::set_new_handler(FailedNewHandler);
108 SET_FLAGS(argv[0], &argc, &argv, true);
110 LOG(INFO) << "Seed = " << FLAGS_seed;
113 TestTemplatedWeights<float>(FLAGS_repeat);
114 TestTemplatedWeights<double>(FLAGS_repeat);
// Repeat the templated tests with "()" parentheses for weight I/O, then reset
// the flag to the empty string.
115 FLAGS_fst_weight_parentheses = "()";
116 TestTemplatedWeights<float>(FLAGS_repeat);
117 TestTemplatedWeights<double>(FLAGS_repeat);
118 FLAGS_fst_weight_parentheses = "";
120 // Makes sure type names for templated weights are consistent.
121 CHECK(TropicalWeight::Type() == "tropical");
122 CHECK(TropicalWeightTpl<double>::Type() != TropicalWeightTpl<float>::Type());
123 CHECK(LogWeight::Type() == "log");
124 CHECK(LogWeightTpl<double>::Type() != LogWeightTpl<float>::Type());
// NOTE(review): `w` and `tw` are constructed but never read. Presumably they
// only check that both types are constructible from a floating-point literal;
// confirm, or remove them.
125 TropicalWeightTpl<double> w(15.0);
126 TropicalWeight tw(15.0);
// Accumulator (adder) results vs. plain Plus()/Minus() sums.
128 TestAdder<TropicalWeight>(1000);
129 TestAdder<LogWeight>(1000);
130 TestSignedAdder<SignedLogWeight>(1000);
134 using LeftStringWeight = StringWeight<int>;
135 using LeftStringWeightGenerate = WeightGenerate<LeftStringWeight>;
136 LeftStringWeightGenerate left_string_generate;
137 WeightTester<LeftStringWeight, LeftStringWeightGenerate> left_string_tester(
138 left_string_generate);
139 left_string_tester.Test(FLAGS_repeat);
141 using RightStringWeight = StringWeight<int, STRING_RIGHT>;
142 using RightStringWeightGenerate = WeightGenerate<RightStringWeight>;
143 RightStringWeightGenerate right_string_generate;
144 WeightTester<RightStringWeight, RightStringWeightGenerate>
145 right_string_tester(right_string_generate);
146 right_string_tester.Test(FLAGS_repeat);
148 // COMPOSITE WEIGHTS AND TESTERS - DEFINITIONS
150 using TropicalGallicWeight = GallicWeight<int, TropicalWeight>;
151 using TropicalGallicWeightGenerate = WeightGenerate<TropicalGallicWeight>;
// NOTE(review): the gallic generator is built with `true` and the GALLIC one
// below with `false`. The argument's meaning is defined by WeightGenerate in
// another header; confirm and document it here.
152 TropicalGallicWeightGenerate tropical_gallic_generate(true);
153 WeightTester<TropicalGallicWeight, TropicalGallicWeightGenerate>
154 tropical_gallic_tester(tropical_gallic_generate);
156 using TropicalGenGallicWeight = GallicWeight<int, TropicalWeight, GALLIC>;
157 using TropicalGenGallicWeightGenerate =
158 WeightGenerate<TropicalGenGallicWeight>;
159 TropicalGenGallicWeightGenerate tropical_gen_gallic_generate(false);
160 WeightTester<TropicalGenGallicWeight, TropicalGenGallicWeightGenerate>
161 tropical_gen_gallic_tester(tropical_gen_gallic_generate);
163 using TropicalProductWeight = ProductWeight<TropicalWeight, TropicalWeight>;
164 using TropicalProductWeightGenerate = WeightGenerate<TropicalProductWeight>;
165 TropicalProductWeightGenerate tropical_product_generate;
166 WeightTester<TropicalProductWeight, TropicalProductWeightGenerate>
167 tropical_product_tester(tropical_product_generate);
169 using TropicalLexicographicWeight =
170 LexicographicWeight<TropicalWeight, TropicalWeight>;
171 using TropicalLexicographicWeightGenerate =
172 WeightGenerate<TropicalLexicographicWeight>;
173 TropicalLexicographicWeightGenerate tropical_lexicographic_generate;
174 WeightTester<TropicalLexicographicWeight,
175 TropicalLexicographicWeightGenerate>
176 tropical_lexicographic_tester(tropical_lexicographic_generate);
178 using TropicalCubeWeight = PowerWeight<TropicalWeight, 3>;
179 using TropicalCubeWeightGenerate = WeightGenerate<TropicalCubeWeight>;
180 TropicalCubeWeightGenerate tropical_cube_generate;
181 WeightTester<TropicalCubeWeight, TropicalCubeWeightGenerate>
182 tropical_cube_tester(tropical_cube_generate);
184 using FirstNestedProductWeight =
185 ProductWeight<TropicalProductWeight, TropicalWeight>;
186 using FirstNestedProductWeightGenerate =
187 WeightGenerate<FirstNestedProductWeight>;
188 FirstNestedProductWeightGenerate first_nested_product_generate;
189 WeightTester<FirstNestedProductWeight, FirstNestedProductWeightGenerate>
190 first_nested_product_tester(first_nested_product_generate);
192 using SecondNestedProductWeight =
193 ProductWeight<TropicalWeight, TropicalProductWeight>;
194 using SecondNestedProductWeightGenerate =
195 WeightGenerate<SecondNestedProductWeight>;
196 SecondNestedProductWeightGenerate second_nested_product_generate;
197 WeightTester<SecondNestedProductWeight, SecondNestedProductWeightGenerate>
198 second_nested_product_tester(second_nested_product_generate);
200 using NestedProductCubeWeight = PowerWeight<FirstNestedProductWeight, 3>;
201 using NestedProductCubeWeightGenerate =
202 WeightGenerate<NestedProductCubeWeight>;
203 NestedProductCubeWeightGenerate nested_product_cube_generate;
204 WeightTester<NestedProductCubeWeight, NestedProductCubeWeightGenerate>
205 nested_product_cube_tester(nested_product_cube_generate);
207 using SparseNestedProductCubeWeight =
208 SparsePowerWeight<NestedProductCubeWeight, size_t>;
209 using SparseNestedProductCubeWeightGenerate =
210 WeightGenerate<SparseNestedProductCubeWeight>;
211 SparseNestedProductCubeWeightGenerate sparse_nested_product_cube_generate;
212 WeightTester<SparseNestedProductCubeWeight,
213 SparseNestedProductCubeWeightGenerate>
214 sparse_nested_product_cube_tester(sparse_nested_product_cube_generate);
216 using LogSparsePowerWeight = SparsePowerWeight<LogWeight, size_t>;
217 using LogSparsePowerWeightGenerate = WeightGenerate<LogSparsePowerWeight>;
218 LogSparsePowerWeightGenerate log_sparse_power_generate;
219 WeightTester<LogSparsePowerWeight, LogSparsePowerWeightGenerate>
220 log_sparse_power_tester(log_sparse_power_generate);
222 using LogLogExpectationWeight = ExpectationWeight<LogWeight, LogWeight>;
223 using LogLogExpectationWeightGenerate =
224 WeightGenerate<LogLogExpectationWeight>;
225 LogLogExpectationWeightGenerate log_log_expectation_generate;
226 WeightTester<LogLogExpectationWeight, LogLogExpectationWeightGenerate>
227 log_log_expectation_tester(log_log_expectation_generate);
229 using LogLogSparseExpectationWeight =
230 ExpectationWeight<LogWeight, LogSparsePowerWeight>;
231 using LogLogSparseExpectationWeightGenerate =
232 WeightGenerate<LogLogSparseExpectationWeight>;
233 LogLogSparseExpectationWeightGenerate log_log_sparse_expectation_generate;
234 WeightTester<LogLogSparseExpectationWeight,
235 LogLogSparseExpectationWeightGenerate>
236 log_log_sparse_expectation_tester(log_log_sparse_expectation_generate);
// Options type for TropicalUnionWeight:
//   - Compare orders elements with NaturalLess<TropicalWeight>.
//   - The binary TropicalWeight functor is presumably the merge policy
//     UnionWeight applies to equivalent elements (confirm in union-weight.h).
//   - ReverseOptions names this same options type.
238 struct UnionWeightOptions {
239 using Compare = NaturalLess<TropicalWeight>;
242 TropicalWeight operator()(const TropicalWeight &w1,
243 const TropicalWeight &w2) const {
248 using ReverseOptions = UnionWeightOptions;
251 using TropicalUnionWeight = UnionWeight<TropicalWeight, UnionWeightOptions>;
252 using TropicalUnionWeightGenerate = WeightGenerate<TropicalUnionWeight>;
253 TropicalUnionWeightGenerate tropical_union_generate;
254 WeightTester<TropicalUnionWeight, TropicalUnionWeightGenerate>
255 tropical_union_tester(tropical_union_generate);
257 // COMPOSITE WEIGHTS AND TESTERS - TESTING
259 // Tests composite weight I/O with parentheses.
260 FLAGS_fst_weight_parentheses = "()";
262 // Unnested composite.
263 tropical_gallic_tester.Test(FLAGS_repeat);
264 tropical_gen_gallic_tester.Test(FLAGS_repeat);
265 tropical_product_tester.Test(FLAGS_repeat);
266 tropical_lexicographic_tester.Test(FLAGS_repeat);
267 tropical_cube_tester.Test(FLAGS_repeat);
268 log_sparse_power_tester.Test(FLAGS_repeat);
// NOTE(review): the expectation and union testers pass a second argument
// `false` to Test(). Its meaning is defined in weight-tester.h; presumably it
// skips a check these weights do not satisfy. Confirm and document it here.
269 log_log_expectation_tester.Test(FLAGS_repeat, false);
270 tropical_union_tester.Test(FLAGS_repeat, false);
273 first_nested_product_tester.Test(FLAGS_repeat);
// NOTE(review): this tester runs only 5 repetitions, while every other tester
// (including this one in the no-parentheses pass) uses FLAGS_repeat.
// Document the reason (e.g. run time) or switch to FLAGS_repeat.
274 second_nested_product_tester.Test(5);
275 nested_product_cube_tester.Test(FLAGS_repeat);
276 sparse_nested_product_cube_tester.Test(FLAGS_repeat);
277 log_log_sparse_expectation_tester.Test(FLAGS_repeat, false);
279 // ... and tests composite weight I/O without parentheses.
280 FLAGS_fst_weight_parentheses = "";
// NOTE(review): this pass skips four testers:
//   tropical_gen_gallic_tester, first_nested_product_tester,
//   nested_product_cube_tester and sparse_nested_product_cube_tester.
// Presumably their text form cannot be parsed back unambiguously without
// parentheses; confirm and document.
282 // Unnested composite.
283 tropical_gallic_tester.Test(FLAGS_repeat);
284 tropical_product_tester.Test(FLAGS_repeat);
285 tropical_lexicographic_tester.Test(FLAGS_repeat);
286 tropical_cube_tester.Test(FLAGS_repeat);
287 log_sparse_power_tester.Test(FLAGS_repeat);
288 log_log_expectation_tester.Test(FLAGS_repeat, false);
289 tropical_union_tester.Test(FLAGS_repeat, false);
292 second_nested_product_tester.Test(FLAGS_repeat);
293 log_log_sparse_expectation_tester.Test(FLAGS_repeat, false);
// Presumably reached only when no CHECK above failed.
295 std::cout << "PASS" << std::endl;