2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from mo.graph.graph import Graph
20 from mo.middle.pattern_match import apply_pattern
def move_scaleshift_to_preprocess_action(graph, match):
    """
    Pattern action: drop a ScaleShift that directly follows an Input when its weights are all ones,
    and remember its biases (with the sign flipped) as the mean values of that input.

    The result is accumulated in the graph attribute 'mean_values', a dict that maps
    an input name to a numpy array, e.g. {'input': np.array(...), 'input2': ...}.

    :param graph: graph to operate on; its edges and its 'mean_values' attribute are modified in place.
    :param match: dict of matched nodes with keys 'input_op', 'scale_shift', 'weights' and 'biases'.
    :return: None
    """
    input_op = match['input_op']
    scale_shift = match['scale_shift']
    weights = np.squeeze(match['weights'].value)
    biases = np.squeeze(match['biases'].value)

    # Only a pure shift can be expressed as mean values. np.any (unlike iterating the array)
    # also works when the squeezed weights are 0-d, i.e. for a single-channel input.
    if np.any(weights != 1):
        return

    # Keep biases (mean values) for current input as graph attr and remove ScaleShift layer
    # Input->data->ScaleShift->scsh_data => Input->scsh_data
    graph.remove_edge(input_op.id, input_op.out_node().id)
    graph.add_edge(input_op.id, scale_shift.out_node().id, out=0)
    graph.remove_edge(scale_shift.id, scale_shift.out_node().id)

    # If bias contains zeros we just remove it
    if np.all(biases == 0):
        return

    # In pre-process section, mean_values are subtracted, hence the sign flip.
    # np.negative builds a new array, so the value stored in the biases node is not
    # modified through the view returned by np.squeeze.
    mean_values = {input_op.name: np.array(np.negative(biases))}

    # Add graph attribute 'mean_values' that stores mean_values per input if exists
    if graph.graph.get('mean_values', None):
        graph.graph['mean_values'].update(mean_values)
    else:
        graph.graph['mean_values'] = mean_values
def move_scaleshift_to_preprocess(graph: Graph):
    """
    This function finds a ScaleShift layer placed right after an Input layer and, if its weights are all ones,
    deletes the ScaleShift layer and creates the graph dict attribute 'mean_values':
    {'input': np.array(...), 'input2': ... }

    The matching is delegated to apply_pattern; the graph rewrite itself is done by
    move_scaleshift_to_preprocess_action, which is passed as the pattern action.

    :param graph: graph to operate on.
    """
    apply_pattern(
        graph,
        nodes=[
            ('weights', dict(kind='data')),
            ('biases', dict(kind='data')),
            ('input_output', dict(kind='data')),
            ('scsh_output', dict(kind='data')),
            ('input_op', dict(kind='op', type='Input')),
            ('scale_shift', dict(kind='op', type='ScaleShift')),
        ],
        edges=[
            ('input_op', 'input_output'),
            ('scale_shift', 'scsh_output'),
            # ScaleShift input ports: 0 - data, 1 - weights, 2 - biases
            ('input_output', 'scale_shift', {'in': 0}),
            ('weights', 'scale_shift', {'in': 1}),
            ('biases', 'scale_shift', {'in': 2}),
        ],
        action=move_scaleshift_to_preprocess_action
    )