Publishing 2019 R3 content
[platform/upstream/dldt.git] / model-optimizer / mo / pipeline / kaldi.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17
18 import numpy as np
19
20 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
21 from extensions.back.CutMemory import CutMemory
22 from extensions.back.ElementwiseOpsToEltwiseOps import DivideToEltwises, SubtractToEltwises, SimpleEltwiseToEltwiseOp
23 from extensions.back.ForceStrictPrecision import ForceStrictPrecision
24 from extensions.back.LeakyReluToReluWithNegativeSlope import LeakyReluToReluWithNegativeSlope
25 from extensions.back.ParameterToPlaceholder import ParameterToInput
26 from extensions.back.TransposeToPermute import TransposeToPermute
27 from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
28 from extensions.back.remove_last_softmax_pattern import RemoveLastSoftMaxPattern
29 from extensions.front.kaldi.eliminate_redundant_reshape import EliminateRedundantReshape
30 from extensions.front.kaldi.fuse_repeated_reshape import FuseRepeatedReshapes
31 from extensions.front.kaldi.replace_lstm_node_pattern import ReplaceLSTMNodePattern
32 from extensions.middle.EltwiseChecker import EltwiseChecker
33 from extensions.middle.InsertSelect import AddSelectBeforeMemoryNodePattern
34 from extensions.middle.RemoveDuplicationMemory import RemoveMemoryDuplicationPattern, MergeNeighborSplicePattern
35 from extensions.middle.RemoveIdentity import RemoveIdentity
36 from extensions.middle.RemoveUselessCrops import RemoveUselessCropsPattern
37 from extensions.middle.ReplaceMemoryOffsetWithSplice import ReplaceMemoryOffsetNodePattern, \
38     ReplaceMemoryOffsetWithMemoryNodePattern
39 from extensions.middle.ReplacePNorm import ReplacePNormNodePattern
40 from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
41 from mo.front.common.register_custom_ops import update_extractors_with_extensions
42 from mo.front.extractor import extract_node_attrs, remove_output_ops
43 from mo.front.kaldi.extractor import kaldi_extractor, kaldi_type_extractors
44 from mo.front.kaldi.loader.loader import load_kaldi_model, read_counts_file
45 from mo.graph.graph import Node
46 from mo.middle.passes.conv import convert_matmul_to_fully_connected
47 from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
48 from mo.middle.passes.infer import partial_infer
49 from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
50 from mo.pipeline.common import prepare_emit_ir
51 from mo.utils import class_registration
52 from mo.utils.cli_parser import get_meta_info
53 from mo.utils.error import Error
54 from mo.utils.find_inputs import find_outputs
55 from mo.utils.logger import log_step
56 from mo.utils.utils import refer_to_faq_msg
57
58
def apply_biases_to_last_layer(graph, counts):
    """
    Subtract the log-apriori counts from the biases of the network's last
    affine layer.

    If the user provides a counts file, it contains log-apriori probabilities;
    technically they should be subtracted from the bias of the last layer
    unless that layer is a SoftMax.

    Case 1:
        weights ---\
        biases  ---\
    some layer  ---> AffineTransform ---> SoftMax

    Then, counts are applied to biases of Affine Transform:

        weights             ---\
        (biases - counts)   ---\
    some layer              ---> AffineTransform ---> SoftMax

    Case 2:
        weights ---\
        biases  ---\
    some layer  ---> AffineTransform

    Just takes the last layer and updates biases:

        weights             ---\
        (biases - counts)   ---\
    some layer              ---> AffineTransform

    Parameters
    ----------
    graph : Graph
        The model graph; must have exactly one non-Memory output.
    counts : np.ndarray
        Log-apriori counts to subtract from the last layer's biases.

    Raises
    ------
    Error
        If the graph has more than one non-Memory output (ambiguous) or none.
    """
    # Memory outputs are internal state, not real network outputs — exclude them.
    outputs_ids = find_outputs(graph)
    for output in outputs_ids.copy():
        node = Node(graph, output)
        if node.in_node().op != 'Memory':
            continue
        outputs_ids.remove(output)

    if len(outputs_ids) > 1:
        raise Error('Ambiguity in applying counts to several outputs.')
    elif len(outputs_ids) == 0:
        raise Error('No outputs were found')

    node = Node(graph, outputs_ids[0])
    target_node = node.in_node()
    if target_node and target_node['op'] == 'SoftMax':
        # Counts must be applied to the layer feeding the SoftMax, so step
        # one op back through the intermediate data node.
        data_node = target_node.in_node()
        target_node = data_node.in_node()

    biases_node = target_node.in_nodes()[2]  # first - input, second - weights, third - biases
    if biases_node.value is not None:
        biases_node.value = np.subtract(biases_node.value, counts)  # pylint: disable=assignment-from-no-return
    else:
        # No existing bias: the result of (0 - counts) is just the negated counts.
        biases_node.value = -counts
        biases_node.shape = counts.shape
119
120
def driver(argv, input_model, output_model_name, output_dir):
    """
    Run the full Kaldi conversion pipeline: load the model, apply
    front/middle/back transformations, and emit the IR.

    Parameters
    ----------
    argv
        Parsed command-line arguments (expects attributes such as steps,
        counts, remove_output_softmax, remove_memory, data_type and the
        IR-version flags).
    input_model
        Path to the input Kaldi model file.
    output_model_name
        Base name for the emitted IR files.
    output_dir
        Directory where the IR is written.

    Returns
    -------
    int
        0 on success.

    Raises
    ------
    Error
        If the Kaldi model or the counts file cannot be parsed.
    """
    log_step(argv.steps, 'LOAD')
    meta_info = get_meta_info(argv)

    # Disable the generic eltwise checker; presumably it conflicts with
    # Kaldi-specific handling — TODO confirm.
    EltwiseChecker.enabled = False

    try:
        graph = load_kaldi_model(input_model)
    except Exception as e:
        raise Error('Model Optimizer is not able to parse Kaldi model {}. '.format(input_model) +
                    refer_to_faq_msg(91)) from e
    graph.check_empty_graph('load_kaldi_nnet_model')
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'kaldi'

    # Select the IR version: experimental V10, default 6, or deprecated V2.
    if graph.graph['cmd_params'].generate_experimental_IR_V10:
        version = 10
    else:
        version = 6
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

    update_extractors_with_extensions(kaldi_type_extractors)
    extract_node_attrs(graph, lambda node: kaldi_extractor(node))

    # --------------------------------- LOAD END ------------------------------------------------------
    log_step(argv.steps, 'FRONT')
    ReplaceLSTMNodePattern().find_and_replace_pattern(graph)
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    log_step(argv.steps, 'MIDDLE')
    graph = partial_infer(graph)

    # Middle transformations: rewrite Kaldi-specific memory/splice constructs
    # into representable ops. The order of these passes matters.
    ReplacePNormNodePattern().find_and_replace_pattern(graph)
    ReplaceMemoryOffsetNodePattern().find_and_replace_pattern(graph)
    ReplaceMemoryOffsetWithMemoryNodePattern().find_and_replace_pattern(graph)
    RemoveMemoryDuplicationPattern().find_and_replace_pattern(graph)
    MergeNeighborSplicePattern().find_and_replace_pattern(graph)
    RemoveUselessCropsPattern().find_and_replace_pattern(graph)
    RemoveIdentity().find_and_replace_pattern(graph)
    graph_clean_up(graph)

    AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)

    ReplaceSpliceNodePattern().find_and_replace_pattern(graph)
    graph_clean_up(graph)

    # The order is intentional, firstly eliminate repeated, then remove redundant
    FuseRepeatedReshapes().find_and_replace_pattern(graph)
    EliminateRedundantReshape().find_and_replace_pattern(graph)
    graph_clean_up(graph)
    graph.check_empty_graph('partial_infer')
    # Optionally fold the user-supplied log-apriori counts into the biases of
    # the last layer (see apply_biases_to_last_layer).
    if argv.counts:
        try:
            counts = read_counts_file(argv.counts)
        except Exception as e:
            raise Error('Model Optimizer is not able to read counts file {}'.format(argv.counts) +
                        refer_to_faq_msg(92)) from e

        apply_biases_to_last_layer(graph, counts)

    if argv.remove_output_softmax:
        RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
        graph_clean_up(graph)
        log.debug("After removing softmax")
        graph.print_graph_stat()

    log_step(argv.steps, 'BACK')
    # Back transformations: lower remaining ops to IR-emittable forms.
    LeakyReluToReluWithNegativeSlope().find_and_replace_pattern(graph)
    TransposeToPermute().find_and_replace_pattern(graph)
    DivideToEltwises().find_and_replace_pattern(graph)
    SubtractToEltwises().find_and_replace_pattern(graph)
    SimpleEltwiseToEltwiseOp().find_and_replace_pattern(graph)
    for_graph_and_each_sub_graph_recursively(graph, convert_matmul_to_fully_connected)

    # Intentionally after all transformations
    if argv.remove_memory:
        CutMemory().find_and_replace_pattern(graph)
        graph_clean_up(graph)
    ParameterToInput().find_and_replace_pattern(graph)

    KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
    ForceStrictPrecision().find_and_replace_pattern(graph)
    remove_const_ops(graph)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    remove_output_ops(graph)
    log_step(argv.steps, 'EMIT')
    prepare_emit_ir(graph, argv.data_type, output_dir, output_model_name, meta_info=meta_info)
    return 0