Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / AddMeanScaleValues.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import numpy as np
17
18 from mo.graph.graph import Graph, Node
19 from mo.middle.replacement import MiddleReplacementPattern
20 from mo.ops.lin_op import Add, Mul
21 from mo.ops.op import Op
22 from mo.utils.error import Error
23 from mo.utils.utils import refer_to_faq_msg
24
25
26 class AddMeanScaleValues(MiddleReplacementPattern):
27     enabled = True
28
29     def run_after(self):
30         return []
31
32     def run_before(self):
33         from extensions.middle.pass_separator import MiddleStart
34         return [MiddleStart]
35
36     @staticmethod
37     def apply_scale(graph: Graph, input_node: Node, node_mean_scale_values: dict):
38         if 'scale' in node_mean_scale_values and node_mean_scale_values['scale'] is not None:
39             if all([x == 1 for x in node_mean_scale_values['scale']]):
40                 return
41             out_node = input_node.out_node()
42             if not input_node.has_valid('shape'):
43                 raise Error("Node {} has not valid shape attribute".format(input_node.id))
44             input_shape = input_node.shape
45
46             # Create Mul node
47             value = 1 / np.array(node_mean_scale_values['scale'])
48             graph.remove_edge(input_node.id, out_node.id)
49
50             mul_node = Mul(graph, dict(name="Mul_"))
51             mul_data = Op.create_input_data_node(graph, "data_mul_", np.array(value))
52             Op.expand_node_shape(mul_data, (len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0))
53             mul_input = Op.create_data_node(graph, input_node, {'shape': out_node.shape})
54
55             mul_node.create_node_with_data(inputs=[mul_input, mul_data], data_nodes=out_node)
56
57     @staticmethod
58     def apply_mean_value(graph: Graph, input_node: Node, node_mean_scale_values: dict):
59         if 'mean' in node_mean_scale_values and node_mean_scale_values['mean'] is not None:
60             if all([x == 0 for x in node_mean_scale_values['mean']]):
61                 return
62             out_node = input_node.out_node()
63             if not input_node.has_valid('shape'):
64                 raise Error("Node {} has not valid shape attribute".format(input_node.id))
65             input_shape = input_node.shape
66             # Create Add node
67             graph.remove_edge(input_node.id, out_node.id)
68
69             value = np.array(node_mean_scale_values['mean']) * (-1)
70
71             add_node = Add(graph, dict(name="Add_"))
72             add_data = Op.create_input_data_node(graph, "data_add_", np.array(value))
73             Op.expand_node_shape(add_data, (len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0))
74             add_input = Op.create_data_node(graph, input_node, {'shape': out_node.shape})
75
76             add_node.create_node_with_data(inputs=[add_input, add_data], data_nodes=out_node)
77
78     def find_and_replace_pattern(self, graph: Graph):
79         input_nodes = {}
80         values = graph.graph['cmd_params'].mean_scale_values
81         for node in graph.nodes():
82             node = Node(graph, node)
83             if node.has_valid('op') and node.op == 'Placeholder':
84                 input_nodes.update({node.id: node})
85
86         if not isinstance(values, dict):
87             if len(values) != len(input_nodes):
88                 raise Error('Numbers of inputs and mean/scale values do not match. ' +
89                             refer_to_faq_msg(61))
90
91             data = np.copy(values)
92             values = {}
93             for idx, key in enumerate(input_nodes.keys()):
94                 values.update(
95                     {
96                         input_nodes[key]['name']: {
97                             'mean': data[idx][0],
98                             'scale': data[idx][1]
99                         }
100                     }
101                 )
102
103         for node_name in values:
104             node_id = graph.get_node_id_by_name(node_name)
105             node_mean_scale_values = values[node_name]
106             if node_id not in input_nodes:
107                 # if the user cutted-off input of the network then input node name specified in the --scale_values
108                 # or --mean_values doesn't correspond to a real input node generated by Model Optimizer. But
109                 # the information about initial input node name is stored in Placeholder's attribute 'initial_node_name'
110                 new_node_id = None
111                 for placeholder in input_nodes.values():
112                     if placeholder.has('initial_node_name') and placeholder.initial_node_name == node_name:
113                         new_node_id = placeholder.id
114                         break
115                 if new_node_id is None:
116                     raise Error('Input with name {} wasn\'t found!'.format(node_name) +
117                                 refer_to_faq_msg(83))
118                 node_id = new_node_id
119
120             input_node = Node(graph, node_id)
121             AddMeanScaleValues.apply_scale(graph, input_node, node_mean_scale_values)
122             AddMeanScaleValues.apply_mean_value(graph, input_node, node_mean_scale_values)