Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / pipeline / kaldi.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17
18 import numpy as np
19
20 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
21 from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
22 from extensions.back.remove_last_softmax_pattern import RemoveLastSoftMaxPattern
23 from extensions.front.kaldi.eliminate_redundant_reshape import EliminateRedundantReshape
24 from extensions.front.kaldi.fuse_repeated_reshape import FuseRepeatedReshapes
25 from extensions.middle.EltwiseChecker import EltwiseChecker
26 from mo.front.common.register_custom_ops import update_extractors_with_extensions
27 from mo.front.extractor import extract_node_attrs, remove_output_ops
28 from mo.front.kaldi.extractor import kaldi_extractor, kaldi_type_extractors
29 from mo.front.kaldi.loader.loader import load_kaldi_model, read_counts_file
30 from mo.graph.graph import Node
31 from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
32 from mo.middle.passes.infer import partial_infer
33 from mo.pipeline.common import prepare_emit_ir
34 from mo.utils import class_registration
35 from mo.utils.cli_parser import get_meta_info
36 from mo.utils.error import Error
37 from mo.utils.find_inputs import find_outputs
38 from mo.utils.utils import refer_to_faq_msg
39
40
41 def apply_biases_to_last_layer(graph, counts):
42     """
43     The idea is the following. If the user provides counts file, it is a file that contains log-apriory probabilities,
44     technically it should be subtracted from the bias of the last layer unless it is a SoftMax.
45     
46     Case 1:
47         weights ---\
48         biases  ---\
49     some layer  ---> AffineTransform ---> SoftMax
50     
51     Then, counts are applied to biases of Affine Transform:
52     
53         weights             ---\
54         (biases - counts)   ---\
55     some layer              ---> AffineTransform ---> SoftMax
56     
57     Case 2:
58         weights ---\
59         biases  ---\
60     some layer  ---> AffineTransform
61     
62     Just takes the last layer and updates biases:
63     
64         weights             ---\
65         (biases - counts)   ---\
66     some layer              ---> AffineTransform
67     
68     Parameters
69     ----------
70     graph
71     counts
72
73     Returns
74     -------
75
76     """""
77     outputs_ids = find_outputs(graph)
78     for output in outputs_ids.copy():
79         node = Node(graph, output)
80         if node.in_node().op != 'Memory':
81             continue
82         outputs_ids.remove(output)
83
84     if len(outputs_ids) > 1:
85         raise Error('Ambiguity in applying counts to several outputs.')
86     elif len(outputs_ids) == 0:
87         raise Error('No outputs were found')
88
89     node = Node(graph, outputs_ids[0])
90     target_node = node.in_node()
91     if target_node and target_node['op'] == 'SoftMax':
92         data_node = target_node.in_node()
93         target_node = data_node.in_node()
94
95     biases_node = target_node.in_nodes()[2]  # first - input, second - weights, third - biases
96     if biases_node.value is not None:
97         biases_node.value = np.subtract(biases_node.value, counts)  # pylint: disable=assignment-from-no-return
98     else:
99         biases_node.value = counts * -1
100         biases_node.shape = counts.shape
101
102
103 def driver(argv, input_model, output_model_name, output_dir):
104     meta_info = get_meta_info(argv)
105
106     EltwiseChecker.enabled = False
107
108     try:
109         graph, input_shapes = load_kaldi_model(input_model)
110     except Exception as e:
111         raise Error('Model Optimizer is not able to read Kaldi model {}. '.format(input_model) +
112                     refer_to_faq_msg(91)) from e
113     graph.check_empty_graph('load_kaldi_nnet_model')
114     graph.graph['cmd_params'] = argv
115     graph.graph['fw'] = 'kaldi'
116     graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
117     update_extractors_with_extensions(kaldi_type_extractors)
118     extract_node_attrs(graph, lambda node: kaldi_extractor(node))
119
120     # --------------------------------- LOAD END ------------------------------------------------------
121     class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
122
123     graph = partial_infer(graph)
124
125     # The order is intentional, firstly eliminate repeated, then remove redundant
126     FuseRepeatedReshapes().find_and_replace_pattern(graph)
127     EliminateRedundantReshape().find_and_replace_pattern(graph)
128     graph.check_empty_graph('partial_infer')
129     if argv.counts:
130         try:
131             counts = read_counts_file(argv.counts)
132         except Exception as e:
133             raise Error('Model Optimizer is not able to read counts file {}'.format(argv.counts) +
134                         refer_to_faq_msg(92)) from e
135
136         apply_biases_to_last_layer(graph, counts)
137
138     if argv.remove_output_softmax:
139         RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
140         graph_clean_up(graph)
141         log.debug("After removing softmax")
142         graph.print_graph_stat()
143
144     # Intentionally after all transformations
145     KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
146
147     remove_const_ops(graph)
148     CreateConstNodesReplacement().find_and_replace_pattern(graph)
149
150     remove_output_ops(graph)
151
152     prepare_emit_ir(graph, argv.data_type, output_dir, output_model_name, meta_info=meta_info)
153     return 0