"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import numpy as np
-from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power
+from mo.graph.graph import Node
+from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, convert_add_or_mul_to_scaleshift
from mo.middle.passes.eliminate import graph_clean_up
from mo.utils.unittest.graph import build_graph, compare_graphs
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
# ScaleShift layer
'scaleshift_1': {'type': 'ScaleShift', 'value': None, 'kind': 'op', 'op': 'ScaleShift'},
+ 'const_scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'op'},
'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
+ 'const_scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'op'},
'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
# Mul and Add operations
'mul_1': {'value': None, 'kind': 'op', 'op': 'Mul'},
+ 'const_mul_1_w': {'value': None, 'shape': None, 'kind': 'op'},
'mul_1_w': {'value': None, 'shape': None, 'kind': 'data'},
'mul_1_data': {'value': None, 'shape': None, 'kind': 'data'},
'add_1': {'value': None, 'kind': 'op', 'op': 'Add'},
+ 'const_add_1_w': {'value': None, 'shape': None, 'kind': 'op'},
'add_1_w': {'value': None, 'shape': None, 'kind': 'data'},
'add_1_data': {'value': None, 'shape': None, 'kind': 'data'},
# Power layer
'power_1': {'type': 'Power', 'kind': 'op', 'op': 'Power', 'scale': None, 'shift': None, 'power': None},
'power_1_data': {'value': None, 'shape': None, 'kind': 'data'},
+ 'op_output': {'kind': 'op', 'op': 'OpOutput'},
}
graph = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'mul_1'),
+ ('const_mul_1_w', 'mul_1_w'),
('mul_1_w', 'mul_1'),
('mul_1', 'mul_1_data'),
('mul_1_data', 'add_1'),
+ ('const_add_1_w', 'add_1_w'),
('add_1_w', 'add_1'),
('add_1', 'add_1_data'),
+ ('add_1_data', 'op_output')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
- 'add_1_data': {'shape': np.array([1, 227, 227, 3]), 'is_output': True},
+ 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
+ 'const_mul_1_w': {'shape': np.array(mul_w.shape) if mul_w is not None else None,
+ 'value': np.array(mul_w) if mul_w is not None else None},
'mul_1_w': {'shape': np.array(mul_w.shape) if mul_w is not None else None,
'value': np.array(mul_w) if mul_w is not None else None},
+ 'const_add_1_w': {'shape': np.array(add_w.shape) if add_w is not None else None,
+ 'value': np.array(add_w) if add_w is not None else None},
'add_1_w': {'shape': np.array(add_w.shape) if add_w is not None else None,
'value': np.array(add_w) if add_w is not None else None},
})
graph_ref = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'scaleshift_1'),
+ ('const_scaleshift_1_w', 'scaleshift_1_w'),
('scaleshift_1_w', 'scaleshift_1'),
+ ('const_scaleshift_1_b', 'scaleshift_1_b'),
('scaleshift_1_b', 'scaleshift_1'),
('scaleshift_1', 'scaleshift_1_data'),
+ ('scaleshift_1_data', 'op_output'),
],
- {'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+ {'const_scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+ 'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+ 'const_scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
- 'scaleshift_1_data': {'is_output': True}
+ 'scaleshift_1_data': {}
})
convert_muladd_to_scaleshift_or_power(graph)
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'power_1'),
('power_1', 'power_1_data'),
+ ('power_1_data', 'op_output'),
],
{'power_1': {'scale': 3, 'shift': 2, 'power': 1},
- 'power_1_data': {'is_output': True}
+ 'power_1_data': {}
})
convert_muladd_to_scaleshift_or_power(graph)
graph_ref = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'scaleshift_1'),
+ ('const_scaleshift_1_w', 'scaleshift_1_w'),
('scaleshift_1_w', 'scaleshift_1'),
+ ('const_scaleshift_1_b', 'scaleshift_1_b'),
('scaleshift_1_b', 'scaleshift_1'),
('scaleshift_1', 'add_1_data'),
+ ('add_1_data', 'op_output'),
],
- {'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([3, 3, 3])},
+ {'const_scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([3, 3, 3])},
+ 'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([3, 3, 3])},
+ 'const_scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([3, 2, 1])},
'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([3, 2, 1])},
- 'add_1_data': {'is_output': True}
})
convert_muladd_to_scaleshift_or_power(graph)
self.assertTrue(flag, resp)
+class AddToScaleShift(unittest.TestCase):
+ @staticmethod
+ def _create_graph_with_add(add_w: np.ndarray):
+ graph = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'add_1'),
+ ('const_add_1_w', 'add_1_w'),
+ ('add_1_w', 'add_1'),
+ ('add_1', 'add_1_data'),
+ ('add_1_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
+ 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
+ 'const_add_1_w': {'shape': np.array(add_w.shape) if add_w is not None else None,
+ 'value': np.array(add_w) if add_w is not None else None},
+ 'add_1_w': {'shape': np.array(add_w.shape) if add_w is not None else None,
+ 'value': np.array(add_w) if add_w is not None else None},
+ }, nodes_with_edges_only=True)
+ del graph['add_1']['add_1_data'][0]['in']
+ return graph
+
+ @staticmethod
+ def _create_graph_with_mul(mul_w: np.ndarray):
+ graph = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'mul_1'),
+ ('const_mul_1_w', 'mul_1_w'),
+ ('mul_1_w', 'mul_1'),
+ ('mul_1', 'mul_1_data'),
+ ('mul_1_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
+ 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
+ 'const_mul_1_w': {'shape': np.array(mul_w.shape) if mul_w is not None else None,
+ 'value': np.array(mul_w) if mul_w is not None else None},
+ 'mul_1_w': {'shape': np.array(mul_w.shape) if mul_w is not None else None,
+ 'value': np.array(mul_w) if mul_w is not None else None},
+ }, nodes_with_edges_only=True)
+ del graph['mul_1']['mul_1_data'][0]['in']
+ return graph
+
+ def test_add_to_scaleshift_1(self):
+ graph = AddToScaleShift._create_graph_with_add(np.array([1, 2, 3], dtype=np.float32))
+ graph.stage = 'middle'
+
+ graph_ref = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'scaleshift_1'),
+ ('const_scaleshift_1_w', 'scaleshift_1_w'),
+ ('const_scaleshift_1_b', 'scaleshift_1_b'),
+ ('scaleshift_1_w', 'scaleshift_1'),
+ ('scaleshift_1_b', 'scaleshift_1'),
+ ('scaleshift_1', 'scaleshift_1_data'),
+ ('scaleshift_1_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
+ 'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])},
+
+ 'const_scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
+ 'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
+
+ 'const_scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+ 'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+ }, nodes_with_edges_only=True)
+
+ convert_add_or_mul_to_scaleshift(graph)
+ graph_clean_up(graph)
+
+ (flag, resp) = compare_graphs(graph, graph_ref, 'op_output')
+ self.assertTrue(flag, resp)
+
+ scsh_node = Node(graph, 'op_output').in_port(0).get_source().node
+
+ self.assertTrue(graph.get_edge_data(scsh_node.in_node(1).id, scsh_node.id)[0]['bin'] == 'weights')
+ self.assertTrue(graph.get_edge_data(scsh_node.in_node(2).id, scsh_node.id)[0]['bin'] == 'biases')
+
+ def test_mul_to_scaleshift_1(self):
+ graph = AddToScaleShift._create_graph_with_mul(np.array([1, 2, 3], dtype=np.float32))
+ graph.stage = 'middle'
+
+ graph_ref = build_graph(nodes_attributes,
+ [('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'scaleshift_1'),
+ ('const_scaleshift_1_w', 'scaleshift_1_w'),
+ ('const_scaleshift_1_b', 'scaleshift_1_b'),
+ ('scaleshift_1_w', 'scaleshift_1'),
+ ('scaleshift_1_b', 'scaleshift_1'),
+ ('scaleshift_1', 'scaleshift_1_data'),
+ ('scaleshift_1_data', 'op_output')
+ ],
+ {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
+ 'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])},
+
+ 'const_scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+ 'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+
+ 'const_scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
+ 'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
+ }, nodes_with_edges_only=True)
+
+ convert_add_or_mul_to_scaleshift(graph)
+ graph_clean_up(graph)
+
+ (flag, resp) = compare_graphs(graph, graph_ref, 'op_output')
+ self.assertTrue(flag, resp)
+
+ scsh_node = Node(graph, 'op_output').in_port(0).get_source().node
+
+ self.assertTrue(graph.get_edge_data(scsh_node.in_node(1).id, scsh_node.id)[0]['bin'] == 'weights')
+ self.assertTrue(graph.get_edge_data(scsh_node.in_node(2).id, scsh_node.id)[0]['bin'] == 'biases')
+
+
+
if __name__ == '__main__':
unittest.main()