"""
 Copyright (c) 2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import numpy as np

from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.lin_op import Add, Mul
from mo.ops.op import Op
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
class AddMeanScaleValues(MiddleReplacementPattern):
    """Apply per-input mean subtraction and scaling right after the network's Placeholders.

    The values are read from ``graph.graph['cmd_params'].mean_scale_values`` (filled from the
    --mean_values / --scale_values options). That object is either a dict keyed by input node
    name whose values are dicts with optional 'mean' / 'scale' entries, or a positional
    sequence with one (mean, scale) entry per Placeholder. For each listed input, a ``Mul`` by
    ``1 / scale`` and an ``Add`` of ``-mean`` are spliced between the Placeholder and its output
    data node. An all-ones scale or an all-zeros mean is skipped.
    """
    enabled = True

    def run_after(self):
        return []

    def run_before(self):
        # Scheduled ahead of the MiddleStart pass separator.
        from extensions.middle.pass_separator import MiddleStart
        return [MiddleStart]

    @staticmethod
    def _insert_eltwise_after_input(graph: Graph, input_node: Node, op_class, op_name: str,
                                    data_name: str, value):
        """Splice a binary eltwise op between ``input_node`` and its current output data node.

        The new op takes, as input 0, a new data node created via
        ``Op.create_data_node(graph, input_node, ...)``. As input 1 it takes a constant data
        node holding ``value``. It writes into the original output data node, so that node
        (and whatever consumes it) is reused as-is.

        :param graph: graph being transformed.
        :param input_node: Placeholder node to insert after.
        :param op_class: operation class to instantiate (``Add`` or ``Mul``).
        :param op_name: name given to the new operation node.
        :param data_name: name given to the constant data node.
        :param value: constant operand, already negated / inverted by the caller.
        :raises Error: if ``input_node`` has no valid ``shape`` attribute.
        """
        out_node = input_node.out_node()
        if not input_node.has_valid('shape'):
            raise Error("Node {} has not valid shape attribute".format(input_node.id))
        input_shape = input_node.shape

        graph.remove_edge(input_node.id, out_node.id)

        op_node = op_class(graph, dict(name=op_name))
        const_data = Op.create_input_data_node(graph, data_name, np.array(value))
        # NOTE(review): for NCHW the constant is presumably padded with len(shape) - 2 extra
        # axes so per-channel values broadcast over the spatial dims — confirm in Op.
        Op.expand_node_shape(const_data, (len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0))
        op_input = Op.create_data_node(graph, input_node, {'shape': out_node.shape})

        op_node.create_node_with_data(inputs=[op_input, const_data], data_nodes=out_node)

    @staticmethod
    def apply_scale(graph: Graph, input_node: Node, node_mean_scale_values: dict):
        """Insert a ``Mul`` by ``1 / scale`` after ``input_node``.

        Does nothing when the 'scale' entry is missing, None, or all ones.

        :raises Error: if ``input_node`` has no valid ``shape`` attribute.
        """
        scale = node_mean_scale_values.get('scale')
        if scale is None or all(x == 1 for x in scale):
            return
        AddMeanScaleValues._insert_eltwise_after_input(
            graph, input_node, Mul, "Mul_", "data_mul_", 1 / np.array(scale))

    @staticmethod
    def apply_mean_value(graph: Graph, input_node: Node, node_mean_scale_values: dict):
        """Insert an ``Add`` of ``-mean`` after ``input_node``.

        Does nothing when the 'mean' entry is missing, None, or all zeros.

        :raises Error: if ``input_node`` has no valid ``shape`` attribute.
        """
        mean = node_mean_scale_values.get('mean')
        if mean is None or all(x == 0 for x in mean):
            return
        AddMeanScaleValues._insert_eltwise_after_input(
            graph, input_node, Add, "Add_", "data_add_", np.array(mean) * (-1))

    def find_and_replace_pattern(self, graph: Graph):
        """Apply the configured mean/scale values to every matching Placeholder.

        :raises Error: if positional values do not match the number of Placeholders, or if a
            named input cannot be mapped to any Placeholder.
        """
        values = graph.graph['cmd_params'].mean_scale_values

        input_nodes = {}
        for node_id in graph.nodes():
            node = Node(graph, node_id)
            if node.has_valid('op') and node.op == 'Placeholder':
                input_nodes[node.id] = node

        if not isinstance(values, dict):
            # Positional form: entry i belongs to the i-th Placeholder in graph iteration order.
            if len(values) != len(input_nodes):
                raise Error('Numbers of inputs and mean/scale values do not match. ' +
                            refer_to_faq_msg(61))
            # Read the entries directly rather than packing them with np.copy. An entry where only
            # one of mean/scale was given is presumably a ragged (array, None) pair, which
            # NumPy >= 1.24 refuses to convert into a regular array.
            values = {
                placeholder['name']: {'mean': entry[0], 'scale': entry[1]}
                for placeholder, entry in zip(input_nodes.values(), values)
            }

        for node_name, node_mean_scale_values in values.items():
            node_id = graph.get_node_id_by_name(node_name)
            if node_id not in input_nodes:
                # If the user cut off the original input of the network, the name given in
                # --scale_values / --mean_values does not match a Placeholder generated by Model
                # Optimizer. The original input name is kept in the Placeholder's
                # 'initial_node_name' attribute, so look it up there.
                node_id = next((placeholder.id for placeholder in input_nodes.values()
                                if placeholder.has('initial_node_name') and
                                placeholder.initial_node_name == node_name), None)
                if node_id is None:
                    raise Error('Input with name {} wasn\'t found!'.format(node_name) +
                                refer_to_faq_msg(83))

            input_node = Node(graph, node_id)
            # Order matters: apply_mean_value re-reads input_node.out_node(), which after
            # apply_scale is the Mul's input. The Add therefore presumably lands upstream of
            # the Mul, giving (x - mean) / scale.
            AddMeanScaleValues.apply_scale(graph, input_node, node_mean_scale_values)
            AddMeanScaleValues.apply_mean_value(graph, input_node, node_mean_scale_values)