12adf266667934f0eddce1c6a13b4ec6085241a4
[platform/upstream/openfst.git] / src / extensions / pdt / pdtreplace.cc
1 // See www.openfst.org for extensive documentation on this weighted
2 // finite-state transducer library.
3 //
4 // Converts an RTN represented by FSTs and non-terminal labels into a PDT.
5
6 #include <cstring>
7
8 #include <string>
9 #include <vector>
10
11 #include <fst/extensions/pdt/getters.h>
12 #include <fst/extensions/pdt/pdtscript.h>
13 #include <fst/util.h>
14 #include <fst/vector-fst.h>
15
16 DEFINE_string(pdt_parentheses, "", "PDT parenthesis label pairs");
17 DEFINE_string(pdt_parser_type, "left",
18               "Construction method, one of: \"left\", \"left_sr\"");
19 DEFINE_int64(start_paren_labels, fst::kNoLabel,
20              "Index to use for the first inserted parentheses; if not "
21              "specified, the next available label beyond the highest output "
22              "label is used");
23 DEFINE_string(left_paren_prefix, "(_", "Prefix to attach to SymbolTable "
24               "labels for inserted left parentheses");
25 DEFINE_string(right_paren_prefix, ")_", "Prefix to attach to SymbolTable "
26               "labels for inserted right parentheses");
27
28 int main(int argc, char **argv) {
29   namespace s = fst::script;
30   using fst::script::FstClass;
31   using fst::script::VectorFstClass;
32   using fst::PdtParserType;
33   using fst::WriteLabelPairs;
34
35   string usage = "Converts an RTN represented by FSTs";
36   usage += " and non-terminal labels into PDT.\n\n  Usage: ";
37   usage += argv[0];
38   usage += " root.fst rootlabel [rule1.fst label1 ...] [out.fst]\n";
39
40   std::set_new_handler(FailedNewHandler);
41   SET_FLAGS(usage.c_str(), &argc, &argv, true);
42   if (argc < 4) {
43     ShowUsage();
44     return 1;
45   }
46
47   const string in_name = argv[1];
48   const string out_name = argc % 2 == 0 ? argv[argc - 1] : "";
49
50   // Replace takes ownership of the pointer of FST arrays, deleting all such
51   // pointers when the underlying ReplaceFst is destroyed.
52   auto *ifst = FstClass::Read(in_name);
53   if (!ifst) return 1;
54
55   PdtParserType parser_type;
56   if (!s::GetPdtParserType(FLAGS_pdt_parser_type, &parser_type)) {
57     LOG(ERROR) << argv[0] << ": Unknown PDT parser type: "
58                << FLAGS_pdt_parser_type;
59     return 1;
60   }
61
62   std::vector<s::LabelFstClassPair> pairs;
63   // Note that if the root label is beyond the range of the underlying FST's
64   // labels, truncation will occur.
65   const auto root = atoll(argv[2]);
66   pairs.emplace_back(root, ifst);
67
68   for (auto i = 3; i < argc - 1; i += 2) {
69     ifst = FstClass::Read(argv[i]);
70     if (!ifst) return 1;
71     // Note that if the root label is beyond the range of the underlying FST's
72     // labels, truncation will occur.
73     const auto label = atoll(argv[i + 1]);
74     pairs.emplace_back(label, ifst);
75   }
76
77   VectorFstClass ofst(ifst->ArcType());
78   std::vector<s::LabelPair> parens;
79   s::PdtReplace(pairs, &ofst, &parens, root, parser_type,
80                 FLAGS_start_paren_labels, FLAGS_left_paren_prefix,
81                 FLAGS_right_paren_prefix);
82
83   if (!FLAGS_pdt_parentheses.empty()) {
84     if (!WriteLabelPairs(FLAGS_pdt_parentheses, parens)) return 1;
85   }
86
87   ofst.Write(out_name);
88
89   return 0;
90 }