"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from extensions.back.CreateConstNodes import CreateConstNodesReplacement
from extensions.back.CutMemory import CutMemory
from extensions.back.ElementwiseOpsToEltwiseOps import DivideToEltwises, SubtractToEltwises, SimpleEltwiseToEltwiseOp
from extensions.back.ForceStrictPrecision import ForceStrictPrecision
from extensions.back.LeakyReluToReluWithNegativeSlope import LeakyReluToReluWithNegativeSlope
from extensions.back.ParameterToPlaceholder import ParameterToInput
from extensions.back.TransposeToPermute import TransposeToPermute
from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
from extensions.back.remove_last_softmax_pattern import RemoveLastSoftMaxPattern
from extensions.front.kaldi.eliminate_redundant_reshape import EliminateRedundantReshape
from extensions.front.kaldi.fuse_repeated_reshape import FuseRepeatedReshapes
from extensions.front.kaldi.replace_lstm_node_pattern import ReplaceLSTMNodePattern
from extensions.middle.EltwiseChecker import EltwiseChecker
from extensions.middle.InsertSelect import AddSelectBeforeMemoryNodePattern
from extensions.middle.RemoveDuplicationMemory import RemoveMemoryDuplicationPattern, MergeNeighborSplicePattern
from extensions.middle.RemoveIdentity import RemoveIdentity
from extensions.middle.RemoveUselessCrops import RemoveUselessCropsPattern
from extensions.middle.ReplaceMemoryOffsetWithSplice import ReplaceMemoryOffsetNodePattern, \
    ReplaceMemoryOffsetWithMemoryNodePattern
from extensions.middle.ReplacePNorm import ReplacePNormNodePattern
from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
from mo.front.common.register_custom_ops import update_extractors_with_extensions
from mo.front.extractor import extract_node_attrs, remove_output_ops
from mo.front.kaldi.extractor import kaldi_extractor, kaldi_type_extractors
from mo.front.kaldi.loader.loader import load_kaldi_model, read_counts_file
from mo.graph.graph import Node
from mo.middle.passes.conv import convert_matmul_to_fully_connected
from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
from mo.middle.passes.infer import partial_infer
from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
from mo.pipeline.common import prepare_emit_ir
from mo.utils import class_registration
from mo.utils.cli_parser import get_meta_info
from mo.utils.error import Error
from mo.utils.find_inputs import find_outputs
from mo.utils.logger import log_step
from mo.utils.utils import refer_to_faq_msg
def apply_biases_to_last_layer(graph, counts):
    """
    Subtract the counts (log-apriori probabilities) from the biases of the last layer.

    The idea is the following. If the user provides counts file, it is a file that contains log-apriory probabilities,
    technically it should be subtracted from the bias of the last layer unless it is a SoftMax.

    Case 1:
        weights ---\
        biases  ---\
    some layer ---> AffineTransform ---> SoftMax

    Then, counts are applied to biases of Affine Transform:
        weights           ---\
        (biases - counts) ---\
    some layer ---> AffineTransform ---> SoftMax

    Case 2:
        weights ---\
        biases  ---\
    some layer ---> AffineTransform

    Just takes the last layer and updates biases:
        weights           ---\
        (biases - counts) ---\
    some layer ---> AffineTransform

    :param graph: model graph after partial inference
    :param counts: array of counts to subtract from the last layer biases
    :raise Error: if there is not exactly one candidate output to apply counts to
    """
    outputs_ids = find_outputs(graph)
    for output in outputs_ids.copy():
        node = Node(graph, output)
        if node.in_node().op != 'Memory':
            continue
        # outputs fed by Memory nodes are service outputs, not the network result
        outputs_ids.remove(output)

    if len(outputs_ids) > 1:
        raise Error('Ambiguity in applying counts to several outputs.')
    elif len(outputs_ids) == 0:
        raise Error('No outputs were found')

    node = Node(graph, outputs_ids[0])
    target_node = node.in_node()
    if target_node and target_node['op'] == 'SoftMax':
        # counts must be applied to the layer producing the SoftMax input,
        # so step over the SoftMax and its input data node
        data_node = target_node.in_node()
        target_node = data_node.in_node()

    biases_node = target_node.in_nodes()[2]  # first - input, second - weights, third - biases
    if biases_node.value is not None:
        biases_node.value = np.subtract(biases_node.value, counts)  # pylint: disable=assignment-from-no-return
    else:
        # no biases yet: (0 - counts) becomes the new bias vector
        biases_node.value = counts * -1
        biases_node.shape = counts.shape
def driver(argv, input_model, output_model_name, output_dir):
    """
    Convert a Kaldi model to Inference Engine IR.

    Runs the full pipeline: LOAD (parse the Kaldi file), FRONT/MIDDLE/BACK
    transformation phases, then EMIT (serialize the IR).

    :param argv: parsed command-line parameters (namespace from cli_parser)
    :param input_model: path to the input Kaldi model file
    :param output_model_name: name for the generated IR files
    :param output_dir: directory where the IR is written
    :raise Error: if the model file or the counts file cannot be read
    """
    log_step(argv.steps, 'LOAD')
    meta_info = get_meta_info(argv)

    # the generic eltwise shape checker is not applicable to Kaldi graphs
    EltwiseChecker.enabled = False

    try:
        graph = load_kaldi_model(input_model)
    except Exception as e:
        raise Error('Model Optimizer is not able to parse Kaldi model {}. '.format(input_model) +
                    refer_to_faq_msg(91)) from e
    graph.check_empty_graph('load_kaldi_nnet_model')
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'kaldi'

    if graph.graph['cmd_params'].generate_experimental_IR_V10:
        version = 10
    else:
        version = 6
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

    update_extractors_with_extensions(kaldi_type_extractors)
    extract_node_attrs(graph, lambda node: kaldi_extractor(node))

    # --------------------------------- LOAD END ------------------------------------------------------
    log_step(argv.steps, 'FRONT')
    ReplaceLSTMNodePattern().find_and_replace_pattern(graph)
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    log_step(argv.steps, 'MIDDLE')
    graph = partial_infer(graph)

    ReplacePNormNodePattern().find_and_replace_pattern(graph)
    ReplaceMemoryOffsetNodePattern().find_and_replace_pattern(graph)
    ReplaceMemoryOffsetWithMemoryNodePattern().find_and_replace_pattern(graph)
    RemoveMemoryDuplicationPattern().find_and_replace_pattern(graph)
    MergeNeighborSplicePattern().find_and_replace_pattern(graph)
    RemoveUselessCropsPattern().find_and_replace_pattern(graph)
    RemoveIdentity().find_and_replace_pattern(graph)
    graph_clean_up(graph)

    AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)

    ReplaceSpliceNodePattern().find_and_replace_pattern(graph)
    graph_clean_up(graph)

    # The order is intentional, firstly eliminate repeated, then remove redundant
    FuseRepeatedReshapes().find_and_replace_pattern(graph)
    EliminateRedundantReshape().find_and_replace_pattern(graph)
    graph_clean_up(graph)
    graph.check_empty_graph('partial_infer')

    if argv.counts:
        try:
            counts = read_counts_file(argv.counts)
        except Exception as e:
            raise Error('Model Optimizer is not able to read counts file {}'.format(argv.counts) +
                        refer_to_faq_msg(92)) from e

        apply_biases_to_last_layer(graph, counts)

    if argv.remove_output_softmax:
        RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
        graph_clean_up(graph)
        log.debug("After removing softmax")
        graph.print_graph_stat()

    log_step(argv.steps, 'BACK')
    LeakyReluToReluWithNegativeSlope().find_and_replace_pattern(graph)
    TransposeToPermute().find_and_replace_pattern(graph)
    DivideToEltwises().find_and_replace_pattern(graph)
    SubtractToEltwises().find_and_replace_pattern(graph)
    SimpleEltwiseToEltwiseOp().find_and_replace_pattern(graph)
    for_graph_and_each_sub_graph_recursively(graph, convert_matmul_to_fully_connected)

    # Intentionally after all transformations
    if argv.remove_memory:
        CutMemory().find_and_replace_pattern(graph)
        graph_clean_up(graph)
        # NOTE(review): CutMemory introduces Parameter nodes that must be turned
        # into inputs; assumed to belong inside this branch - confirm against VCS history
        ParameterToInput().find_and_replace_pattern(graph)

    KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
    ForceStrictPrecision().find_and_replace_pattern(graph)
    remove_const_ops(graph)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    remove_output_ops(graph)
    log_step(argv.steps, 'EMIT')
    prepare_emit_ir(graph, argv.data_type, output_dir, output_model_name, meta_info=meta_info)