Copyright (c) 2018-2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
20 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
21 from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
22 from extensions.back.remove_last_softmax_pattern import RemoveLastSoftMaxPattern
23 from extensions.front.kaldi.eliminate_redundant_reshape import EliminateRedundantReshape
24 from extensions.front.kaldi.fuse_repeated_reshape import FuseRepeatedReshapes
25 from extensions.middle.EltwiseChecker import EltwiseChecker
26 from mo.front.common.register_custom_ops import update_extractors_with_extensions
27 from mo.front.extractor import extract_node_attrs, remove_output_ops
28 from mo.front.kaldi.extractor import kaldi_extractor, kaldi_type_extractors
29 from mo.front.kaldi.loader.loader import load_kaldi_model, read_counts_file
30 from mo.graph.graph import Node
31 from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
32 from mo.middle.passes.infer import partial_infer
33 from mo.pipeline.common import prepare_emit_ir
34 from mo.utils import class_registration
35 from mo.utils.cli_parser import get_meta_info
36 from mo.utils.error import Error
37 from mo.utils.find_inputs import find_outputs
38 from mo.utils.utils import refer_to_faq_msg
def apply_biases_to_last_layer(graph, counts):
    """
    Subtract the log-apriori ``counts`` from the biases of the last affine layer.

    The idea is the following. If the user provides a counts file, it is a file that
    contains log-apriori probabilities; technically they should be subtracted from the
    bias of the last layer unless it is a SoftMax.

    Case 1:
        weights ---\\
        biases  ---\\
        some layer ---> AffineTransform ---> SoftMax

    Then, counts are applied to biases of the Affine Transform:
        weights           ---\\
        (biases - counts) ---\\
        some layer ---> AffineTransform ---> SoftMax

    Case 2:
        some layer ---> AffineTransform

    Just takes the last layer and updates biases:
        weights           ---\\
        (biases - counts) ---\\
        some layer ---> AffineTransform

    :param graph: model graph
    :param counts: numpy array of log-apriori counts to subtract from the biases
    :raises Error: if there are several candidate outputs or none at all
    :return: None (the graph is modified in place)
    """
    outputs_ids = find_outputs(graph)
    # Drop outputs that are fed by 'Memory' nodes: they are internal state outputs,
    # not the classification head the counts should be applied to.
    for output in outputs_ids.copy():
        node = Node(graph, output)
        if node.in_node().op != 'Memory':
            continue
        outputs_ids.remove(output)

    if len(outputs_ids) > 1:
        raise Error('Ambiguity in applying counts to several outputs.')
    elif len(outputs_ids) == 0:
        raise Error('No outputs were found')

    node = Node(graph, outputs_ids[0])
    target_node = node.in_node()
    if target_node and target_node['op'] == 'SoftMax':
        # Counts are applied to the layer before the SoftMax: step over the SoftMax
        # and its input data node to reach the producing AffineTransform.
        data_node = target_node.in_node()
        target_node = data_node.in_node()

    biases_node = target_node.in_nodes()[2]  # first - input, second - weights, third - biases
    if biases_node.value is not None:
        biases_node.value = np.subtract(biases_node.value, counts)  # pylint: disable=assignment-from-no-return
    else:
        # No existing biases: the result of (0 - counts) becomes the new bias vector.
        biases_node.value = counts * -1
        biases_node.shape = counts.shape
def driver(argv, input_model, output_model_name, output_dir):
    """
    Convert a Kaldi model to IR: load, extract, transform, and emit.

    :param argv: parsed command-line namespace (uses ``counts``, ``remove_output_softmax``,
                 ``generate_deprecated_IR_V2``, ``data_type`` — TODO confirm full set with cli_parser)
    :param input_model: path to the input Kaldi model file
    :param output_model_name: name for the emitted IR files
    :param output_dir: directory where the IR is written
    :raises Error: if the model or the counts file cannot be read
    """
    meta_info = get_meta_info(argv)

    # Eltwise shape checks are not applicable to the Kaldi pipeline.
    EltwiseChecker.enabled = False

    try:
        graph, input_shapes = load_kaldi_model(input_model)
    except Exception as e:
        raise Error('Model Optimizer is not able to read Kaldi model {}. '.format(input_model) +
                    refer_to_faq_msg(91)) from e
    graph.check_empty_graph('load_kaldi_nnet_model')
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'kaldi'
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5

    update_extractors_with_extensions(kaldi_type_extractors)
    extract_node_attrs(graph, lambda node: kaldi_extractor(node))

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)

    graph = partial_infer(graph)

    # The order is intentional, firstly eliminate repeated, then remove redundant
    FuseRepeatedReshapes().find_and_replace_pattern(graph)
    EliminateRedundantReshape().find_and_replace_pattern(graph)
    graph.check_empty_graph('partial_infer')

    # Counts file is optional: only apply log-apriori counts when the user supplied one.
    if argv.counts:
        try:
            counts = read_counts_file(argv.counts)
        except Exception as e:
            raise Error('Model Optimizer is not able to read counts file {}'.format(argv.counts) +
                        refer_to_faq_msg(92)) from e

        apply_biases_to_last_layer(graph, counts)

    if argv.remove_output_softmax:
        RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
        graph_clean_up(graph)
        log.debug("After removing softmax")
        graph.print_graph_stat()

    # Intentionally after all transformations
    KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)

    remove_const_ops(graph)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    remove_output_ops(graph)

    prepare_emit_ir(graph, argv.data_type, output_dir, output_model_name, meta_info=meta_info)