Imported Upstream version 1.6.4
[platform/upstream/openfst.git] / src / extensions / pdt / pdtreplace.cc
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Converts an RTN represented by FSTs and non-terminal labels into a PDT.
5
6 #include <cstring>
7
8 #include <string>
9 #include <vector>
10
11 #include <fst/extensions/pdt/getters.h>
12 #include <fst/extensions/pdt/pdtscript.h>
13 #include <fst/util.h>
14 #include <fst/vector-fst.h>
15
16 DEFINE_string(pdt_parentheses, "", "PDT parenthesis label pairs");
17 DEFINE_string(pdt_parser_type, "left",
18               "Construction method, one of: \"left\", \"left_sr\"");
19 DEFINE_int64(start_paren_labels, fst::kNoLabel,
20              "Index to use for the first inserted parentheses; if not "
21              "specified, the next available label beyond the highest output "
22              "label is used");
23 DEFINE_string(left_paren_prefix, "(_", "Prefix to attach to SymbolTable "
24               "labels for inserted left parentheses");
25 DEFINE_string(right_paren_prefix, ")_", "Prefix to attach to SymbolTable "
26               "labels for inserted right parentheses");
27
28 void Cleanup(std::vector<fst::script::LabelFstClassPair> *pairs) {
29   for (const auto &pair : *pairs) {
30     delete pair.second;
31   }
32   pairs->clear();
33 }
34
35 int main(int argc, char **argv) {
36   namespace s = fst::script;
37   using fst::script::FstClass;
38   using fst::script::VectorFstClass;
39   using fst::PdtParserType;
40   using fst::WriteLabelPairs;
41
42   string usage = "Converts an RTN represented by FSTs";
43   usage += " and non-terminal labels into PDT.\n\n  Usage: ";
44   usage += argv[0];
45   usage += " root.fst rootlabel [rule1.fst label1 ...] [out.fst]\n";
46
47   std::set_new_handler(FailedNewHandler);
48   SET_FLAGS(usage.c_str(), &argc, &argv, true);
49   if (argc < 4) {
50     ShowUsage();
51     return 1;
52   }
53
54   const string in_name = argv[1];
55   const string out_name = argc % 2 == 0 ? argv[argc - 1] : "";
56
57   auto *ifst = FstClass::Read(in_name);
58   if (!ifst) return 1;
59
60   PdtParserType parser_type;
61   if (!s::GetPdtParserType(FLAGS_pdt_parser_type, &parser_type)) {
62     LOG(ERROR) << argv[0] << ": Unknown PDT parser type: "
63                << FLAGS_pdt_parser_type;
64     delete ifst;
65     return 1;
66   }
67
68   std::vector<s::LabelFstClassPair> pairs;
69   // Note that if the root label is beyond the range of the underlying FST's
70   // labels, truncation will occur.
71   const auto root = atoll(argv[2]);
72   pairs.emplace_back(root, ifst);
73
74   for (auto i = 3; i < argc - 1; i += 2) {
75     ifst = FstClass::Read(argv[i]);
76     if (!ifst) {
77       Cleanup(&pairs);
78       return 1;
79     }
80     // Note that if the root label is beyond the range of the underlying FST's
81     // labels, truncation will occur.
82     const auto label = atoll(argv[i + 1]);
83     pairs.emplace_back(label, ifst);
84   }
85
86   VectorFstClass ofst(ifst->ArcType());
87   std::vector<s::LabelPair> parens;
88   s::PdtReplace(pairs, &ofst, &parens, root, parser_type,
89                 FLAGS_start_paren_labels, FLAGS_left_paren_prefix,
90                 FLAGS_right_paren_prefix);
91   Cleanup(&pairs);
92
93   if (!FLAGS_pdt_parentheses.empty()) {
94     if (!WriteLabelPairs(FLAGS_pdt_parentheses, parens)) return 1;
95   }
96
97   ofst.Write(out_name);
98
99   return 0;
100 }