25d9fdbf3269115ca644083af121977363c1d136
[platform/upstream/openfst.git] / src / test / weight_test.cc
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Regression test for FST weights.
5
6 #include <cstdlib>
7 #include <ctime>
8
9 #include <fst/log.h>
10 #include <fst/expectation-weight.h>
11 #include <fst/float-weight.h>
12 #include <fst/lexicographic-weight.h>
13 #include <fst/power-weight.h>
14 #include <fst/product-weight.h>
15 #include <fst/signed-log-weight.h>
16 #include <fst/sparse-power-weight.h>
17 #include <fst/string-weight.h>
18 #include <fst/union-weight.h>
19 #include "./weight-tester.h"
20
21 DEFINE_int32(seed, -1, "random seed");
22 DEFINE_int32(repeat, 10000, "number of test repetitions");
23
24 namespace {
25
26 using fst::Adder;
27 using fst::ExpectationWeight;
28 using fst::GALLIC;
29 using fst::GallicWeight;
30 using fst::LexicographicWeight;
31 using fst::LogWeight;
32 using fst::LogWeightTpl;
33 using fst::MinMaxWeight;
34 using fst::MinMaxWeightTpl;
35 using fst::NaturalLess;
36 using fst::PowerWeight;
37 using fst::ProductWeight;
38 using fst::SignedLogWeight;
39 using fst::SignedLogWeightTpl;
40 using fst::SparsePowerWeight;
41 using fst::StringWeight;
42 using fst::STRING_LEFT;
43 using fst::STRING_RIGHT;
44 using fst::TropicalWeight;
45 using fst::TropicalWeightTpl;
46 using fst::UnionWeight;
47 using fst::WeightGenerate;
48 using fst::WeightTester;
49
50 template <class T>
51 void TestTemplatedWeights(int repeat) {
52   using TropicalWeightGenerate = WeightGenerate<TropicalWeightTpl<T>>;
53   TropicalWeightGenerate tropical_generate;
54   WeightTester<TropicalWeightTpl<T>, TropicalWeightGenerate> tropical_tester(
55       tropical_generate);
56   tropical_tester.Test(repeat);
57
58   using LogWeightGenerate = WeightGenerate<LogWeightTpl<T>>;
59   LogWeightGenerate log_generate;
60   WeightTester<LogWeightTpl<T>, LogWeightGenerate> log_tester(log_generate);
61   log_tester.Test(repeat);
62
63   using MinMaxWeightGenerate = WeightGenerate<MinMaxWeightTpl<T>>;
64   MinMaxWeightGenerate minmax_generate(true);
65   WeightTester<MinMaxWeightTpl<T>, MinMaxWeightGenerate> minmax_tester(
66       minmax_generate);
67   minmax_tester.Test(repeat);
68
69   using SignedLogWeightGenerate = WeightGenerate<SignedLogWeightTpl<T>>;
70   SignedLogWeightGenerate signedlog_generate;
71   WeightTester<SignedLogWeightTpl<T>, SignedLogWeightGenerate>
72       signedlog_tester(signedlog_generate);
73   signedlog_tester.Test(repeat);
74 }
75
76 template <class Weight>
77 void TestAdder(int n) {
78   Weight sum = Weight::Zero();
79   Adder<Weight> adder;
80   for (int i = 0; i < n; ++i) {
81     sum = Plus(sum, Weight::One());
82     adder.Add(Weight::One());
83   }
84   CHECK(ApproxEqual(sum, adder.Sum()));
85 }
86
87 template <class Weight>
88 void TestSignedAdder(int n) {
89   Weight sum = Weight::Zero();
90   Adder<Weight> adder;
91   const Weight minus_one = Minus(Weight::Zero(), Weight::One());
92   for (int i = 0; i < n; ++i) {
93     if (i < n/4 || i > 3*n/4) {
94       sum = Plus(sum, Weight::One());
95       adder.Add(Weight::One());
96     } else {
97       sum = Minus(sum, Weight::One());
98       adder.Add(minus_one);
99     }
100   }
101   CHECK(ApproxEqual(sum, adder.Sum()));
102 }
103
104 }  // namespace
105
106 int main(int argc, char **argv) {
107   std::set_new_handler(FailedNewHandler);
108   SET_FLAGS(argv[0], &argc, &argv, true);
109
110   LOG(INFO) << "Seed = " << FLAGS_seed;
111   srand(FLAGS_seed);
112
113   TestTemplatedWeights<float>(FLAGS_repeat);
114   TestTemplatedWeights<double>(FLAGS_repeat);
115   FLAGS_fst_weight_parentheses = "()";
116   TestTemplatedWeights<float>(FLAGS_repeat);
117   TestTemplatedWeights<double>(FLAGS_repeat);
118   FLAGS_fst_weight_parentheses = "";
119
120   // Makes sure type names for templated weights are consistent.
121   CHECK(TropicalWeight::Type() == "tropical");
122   CHECK(TropicalWeightTpl<double>::Type() != TropicalWeightTpl<float>::Type());
123   CHECK(LogWeight::Type() == "log");
124   CHECK(LogWeightTpl<double>::Type() != LogWeightTpl<float>::Type());
125   TropicalWeightTpl<double> w(15.0);
126   TropicalWeight tw(15.0);
127
128   TestAdder<TropicalWeight>(1000);
129   TestAdder<LogWeight>(1000);
130   TestSignedAdder<SignedLogWeight>(1000);
131
132   return 0;
133
134   using LeftStringWeight = StringWeight<int>;
135   using LeftStringWeightGenerate = WeightGenerate<LeftStringWeight>;
136   LeftStringWeightGenerate left_string_generate;
137   WeightTester<LeftStringWeight, LeftStringWeightGenerate> left_string_tester(
138       left_string_generate);
139   left_string_tester.Test(FLAGS_repeat);
140
141   using RightStringWeight = StringWeight<int, STRING_RIGHT>;
142   using RightStringWeightGenerate = WeightGenerate<RightStringWeight>;
143   RightStringWeightGenerate right_string_generate;
144   WeightTester<RightStringWeight, RightStringWeightGenerate>
145       right_string_tester(right_string_generate);
146   right_string_tester.Test(FLAGS_repeat);
147
148   // COMPOSITE WEIGHTS AND TESTERS - DEFINITIONS
149
150   using TropicalGallicWeight = GallicWeight<int, TropicalWeight>;
151   using TropicalGallicWeightGenerate = WeightGenerate<TropicalGallicWeight>;
152   TropicalGallicWeightGenerate tropical_gallic_generate(true);
153   WeightTester<TropicalGallicWeight, TropicalGallicWeightGenerate>
154       tropical_gallic_tester(tropical_gallic_generate);
155
156   using TropicalGenGallicWeight = GallicWeight<int, TropicalWeight, GALLIC>;
157   using TropicalGenGallicWeightGenerate =
158       WeightGenerate<TropicalGenGallicWeight>;
159   TropicalGenGallicWeightGenerate tropical_gen_gallic_generate(false);
160   WeightTester<TropicalGenGallicWeight, TropicalGenGallicWeightGenerate>
161       tropical_gen_gallic_tester(tropical_gen_gallic_generate);
162
163   using TropicalProductWeight = ProductWeight<TropicalWeight, TropicalWeight>;
164   using TropicalProductWeightGenerate = WeightGenerate<TropicalProductWeight>;
165   TropicalProductWeightGenerate tropical_product_generate;
166   WeightTester<TropicalProductWeight, TropicalProductWeightGenerate>
167       tropical_product_tester(tropical_product_generate);
168
169   using TropicalLexicographicWeight =
170       LexicographicWeight<TropicalWeight, TropicalWeight>;
171   using TropicalLexicographicWeightGenerate =
172       WeightGenerate<TropicalLexicographicWeight>;
173   TropicalLexicographicWeightGenerate tropical_lexicographic_generate;
174   WeightTester<TropicalLexicographicWeight,
175                TropicalLexicographicWeightGenerate>
176       tropical_lexicographic_tester(tropical_lexicographic_generate);
177
178   using TropicalCubeWeight = PowerWeight<TropicalWeight, 3>;
179   using TropicalCubeWeightGenerate = WeightGenerate<TropicalCubeWeight>;
180   TropicalCubeWeightGenerate tropical_cube_generate;
181   WeightTester<TropicalCubeWeight, TropicalCubeWeightGenerate>
182       tropical_cube_tester(tropical_cube_generate);
183
184   using FirstNestedProductWeight =
185       ProductWeight<TropicalProductWeight, TropicalWeight>;
186   using FirstNestedProductWeightGenerate =
187       WeightGenerate<FirstNestedProductWeight>;
188   FirstNestedProductWeightGenerate first_nested_product_generate;
189   WeightTester<FirstNestedProductWeight, FirstNestedProductWeightGenerate>
190       first_nested_product_tester(first_nested_product_generate);
191
192   using SecondNestedProductWeight =
193       ProductWeight<TropicalWeight, TropicalProductWeight>;
194   using SecondNestedProductWeightGenerate =
195       WeightGenerate<SecondNestedProductWeight>;
196   SecondNestedProductWeightGenerate second_nested_product_generate;
197   WeightTester<SecondNestedProductWeight, SecondNestedProductWeightGenerate>
198       second_nested_product_tester(second_nested_product_generate);
199
200   using NestedProductCubeWeight = PowerWeight<FirstNestedProductWeight, 3>;
201   using NestedProductCubeWeightGenerate =
202       WeightGenerate<NestedProductCubeWeight>;
203   NestedProductCubeWeightGenerate nested_product_cube_generate;
204   WeightTester<NestedProductCubeWeight, NestedProductCubeWeightGenerate>
205       nested_product_cube_tester(nested_product_cube_generate);
206
207   using SparseNestedProductCubeWeight =
208       SparsePowerWeight<NestedProductCubeWeight, size_t>;
209   using SparseNestedProductCubeWeightGenerate =
210       WeightGenerate<SparseNestedProductCubeWeight>;
211   SparseNestedProductCubeWeightGenerate sparse_nested_product_cube_generate;
212   WeightTester<SparseNestedProductCubeWeight,
213                SparseNestedProductCubeWeightGenerate>
214       sparse_nested_product_cube_tester(sparse_nested_product_cube_generate);
215
216   using LogSparsePowerWeight = SparsePowerWeight<LogWeight, size_t>;
217   using LogSparsePowerWeightGenerate = WeightGenerate<LogSparsePowerWeight>;
218   LogSparsePowerWeightGenerate log_sparse_power_generate;
219   WeightTester<LogSparsePowerWeight, LogSparsePowerWeightGenerate>
220       log_sparse_power_tester(log_sparse_power_generate);
221
222   using LogLogExpectationWeight = ExpectationWeight<LogWeight, LogWeight>;
223   using LogLogExpectationWeightGenerate =
224       WeightGenerate<LogLogExpectationWeight>;
225   LogLogExpectationWeightGenerate log_log_expectation_generate;
226   WeightTester<LogLogExpectationWeight, LogLogExpectationWeightGenerate>
227       log_log_expectation_tester(log_log_expectation_generate);
228
229   using LogLogSparseExpectationWeight =
230       ExpectationWeight<LogWeight, LogSparsePowerWeight>;
231   using LogLogSparseExpectationWeightGenerate =
232       WeightGenerate<LogLogSparseExpectationWeight>;
233   LogLogSparseExpectationWeightGenerate log_log_sparse_expectation_generate;
234   WeightTester<LogLogSparseExpectationWeight,
235                LogLogSparseExpectationWeightGenerate>
236       log_log_sparse_expectation_tester(log_log_sparse_expectation_generate);
237
238   struct UnionWeightOptions {
239     using Compare = NaturalLess<TropicalWeight>;
240
241     struct Merge {
242       TropicalWeight operator()(const TropicalWeight &w1,
243                                 const TropicalWeight &w2) const {
244         return w1;
245       }
246     };
247
248     using ReverseOptions = UnionWeightOptions;
249   };
250
251   using TropicalUnionWeight = UnionWeight<TropicalWeight, UnionWeightOptions>;
252   using TropicalUnionWeightGenerate = WeightGenerate<TropicalUnionWeight>;
253   TropicalUnionWeightGenerate tropical_union_generate;
254   WeightTester<TropicalUnionWeight, TropicalUnionWeightGenerate>
255       tropical_union_tester(tropical_union_generate);
256
257   // COMPOSITE WEIGHTS AND TESTERS - TESTING
258
259   // Tests composite weight I/O with parentheses.
260   FLAGS_fst_weight_parentheses = "()";
261
262   // Unnested composite.
263   tropical_gallic_tester.Test(FLAGS_repeat);
264   tropical_gen_gallic_tester.Test(FLAGS_repeat);
265   tropical_product_tester.Test(FLAGS_repeat);
266   tropical_lexicographic_tester.Test(FLAGS_repeat);
267   tropical_cube_tester.Test(FLAGS_repeat);
268   log_sparse_power_tester.Test(FLAGS_repeat);
269   log_log_expectation_tester.Test(FLAGS_repeat, false);
270   tropical_union_tester.Test(FLAGS_repeat, false);
271
272   // Nested composite.
273   first_nested_product_tester.Test(FLAGS_repeat);
274   second_nested_product_tester.Test(5);
275   nested_product_cube_tester.Test(FLAGS_repeat);
276   sparse_nested_product_cube_tester.Test(FLAGS_repeat);
277   log_log_sparse_expectation_tester.Test(FLAGS_repeat, false);
278
279   // ... and tests composite weight I/O without parentheses.
280   FLAGS_fst_weight_parentheses = "";
281
282   // Unnested composite.
283   tropical_gallic_tester.Test(FLAGS_repeat);
284   tropical_product_tester.Test(FLAGS_repeat);
285   tropical_lexicographic_tester.Test(FLAGS_repeat);
286   tropical_cube_tester.Test(FLAGS_repeat);
287   log_sparse_power_tester.Test(FLAGS_repeat);
288   log_log_expectation_tester.Test(FLAGS_repeat, false);
289   tropical_union_tester.Test(FLAGS_repeat, false);
290
291   // Nested composite.
292   second_nested_product_tester.Test(FLAGS_repeat);
293   log_log_sparse_expectation_tester.Test(FLAGS_repeat, false);
294
295   std::cout << "PASS" << std::endl;
296
297   return 0;
298 }