Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / conv_test.py
index ad4e3aa..9b1fd73 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -18,7 +18,8 @@ import unittest
 
 import numpy as np
 
-from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power
+from mo.graph.graph import Node
+from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, convert_add_or_mul_to_scaleshift
 from mo.middle.passes.eliminate import graph_clean_up
 from mo.utils.unittest.graph import build_graph, compare_graphs
 
@@ -27,19 +28,24 @@ nodes_attributes = {
     'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
     # ScaleShift layer
     'scaleshift_1': {'type': 'ScaleShift', 'value': None, 'kind': 'op', 'op': 'ScaleShift'},
+    'const_scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'op'},
     'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
+    'const_scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'op'},
     'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
     'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
     # Mul and Add operations
     'mul_1': {'value': None, 'kind': 'op', 'op': 'Mul'},
+    'const_mul_1_w': {'value': None, 'shape': None, 'kind': 'op'},
     'mul_1_w': {'value': None, 'shape': None, 'kind': 'data'},
     'mul_1_data': {'value': None, 'shape': None, 'kind': 'data'},
     'add_1': {'value': None, 'kind': 'op', 'op': 'Add'},
+    'const_add_1_w': {'value': None, 'shape': None, 'kind': 'op'},
     'add_1_w': {'value': None, 'shape': None, 'kind': 'data'},
     'add_1_data': {'value': None, 'shape': None, 'kind': 'data'},
     # Power layer
     'power_1': {'type': 'Power', 'kind': 'op', 'op': 'Power', 'scale': None, 'shift': None, 'power': None},
     'power_1_data': {'value': None, 'shape': None, 'kind': 'data'},
+    'op_output': {'kind': 'op', 'op': 'OpOutput'},
 }
 
 
@@ -48,17 +54,24 @@ class MulAddToScaleShiftOrPower(unittest.TestCase):
         graph = build_graph(nodes_attributes,
                             [('placeholder_1', 'placeholder_1_data'),
                              ('placeholder_1_data', 'mul_1'),
+                             ('const_mul_1_w', 'mul_1_w'),
                              ('mul_1_w', 'mul_1'),
                              ('mul_1', 'mul_1_data'),
                              ('mul_1_data', 'add_1'),
+                             ('const_add_1_w', 'add_1_w'),
                              ('add_1_w', 'add_1'),
                              ('add_1', 'add_1_data'),
+                             ('add_1_data', 'op_output')
                              ],
                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
-                             'add_1_data': {'shape': np.array([1, 227, 227, 3]), 'is_output': True},
+                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
+                             'const_mul_1_w': {'shape': np.array(mul_w.shape) if mul_w is not None else None,
+                                               'value': np.array(mul_w) if mul_w is not None else None},
                              'mul_1_w': {'shape': np.array(mul_w.shape) if mul_w is not None else None,
                                          'value': np.array(mul_w) if mul_w is not None else None},
+                             'const_add_1_w': {'shape': np.array(add_w.shape) if add_w is not None else None,
+                                               'value': np.array(add_w) if add_w is not None else None},
                              'add_1_w': {'shape': np.array(add_w.shape) if add_w is not None else None,
                                          'value': np.array(add_w) if add_w is not None else None},
                              })
@@ -72,13 +85,18 @@ class MulAddToScaleShiftOrPower(unittest.TestCase):
         graph_ref = build_graph(nodes_attributes,
                                 [('placeholder_1', 'placeholder_1_data'),
                                  ('placeholder_1_data', 'scaleshift_1'),
+                                 ('const_scaleshift_1_w', 'scaleshift_1_w'),
                                  ('scaleshift_1_w', 'scaleshift_1'),
+                                 ('const_scaleshift_1_b', 'scaleshift_1_b'),
                                  ('scaleshift_1_b', 'scaleshift_1'),
                                  ('scaleshift_1', 'scaleshift_1_data'),
+                                 ('scaleshift_1_data', 'op_output'),
                                  ],
-                                {'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+                                {'const_scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+                                 'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+                                 'const_scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                                  'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
-                                 'scaleshift_1_data': {'is_output': True}
+                                 'scaleshift_1_data': {}
                                  })
 
         convert_muladd_to_scaleshift_or_power(graph)
@@ -93,9 +111,10 @@ class MulAddToScaleShiftOrPower(unittest.TestCase):
                                 [('placeholder_1', 'placeholder_1_data'),
                                  ('placeholder_1_data', 'power_1'),
                                  ('power_1', 'power_1_data'),
+                                 ('power_1_data', 'op_output'),
                                  ],
                                 {'power_1': {'scale': 3, 'shift': 2, 'power': 1},
-                                 'power_1_data': {'is_output': True}
+                                 'power_1_data': {}
                                  })
 
         convert_muladd_to_scaleshift_or_power(graph)
@@ -144,13 +163,17 @@ class MulAddToScaleShiftOrPower(unittest.TestCase):
         graph_ref = build_graph(nodes_attributes,
                                 [('placeholder_1', 'placeholder_1_data'),
                                  ('placeholder_1_data', 'scaleshift_1'),
+                                 ('const_scaleshift_1_w', 'scaleshift_1_w'),
                                  ('scaleshift_1_w', 'scaleshift_1'),
+                                 ('const_scaleshift_1_b', 'scaleshift_1_b'),
                                  ('scaleshift_1_b', 'scaleshift_1'),
                                  ('scaleshift_1', 'add_1_data'),
+                                 ('add_1_data', 'op_output'),
                                  ],
-                                {'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([3, 3, 3])},
+                                {'const_scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([3, 3, 3])},
+                                 'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([3, 3, 3])},
+                                 'const_scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([3, 2, 1])},
                                  'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([3, 2, 1])},
-                                 'add_1_data': {'is_output': True}
                                  })
 
         convert_muladd_to_scaleshift_or_power(graph)
@@ -159,5 +182,118 @@ class MulAddToScaleShiftOrPower(unittest.TestCase):
         self.assertTrue(flag, resp)
 
 
+class AddToScaleShift(unittest.TestCase):
+    @staticmethod
+    def _create_graph_with_add(add_w: np.ndarray):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_1_data', 'add_1'),
+                             ('const_add_1_w', 'add_1_w'),
+                             ('add_1_w', 'add_1'),
+                             ('add_1', 'add_1_data'),
+                             ('add_1_data', 'op_output')
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
+                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
+                             'const_add_1_w': {'shape': np.array(add_w.shape) if add_w is not None else None,
+                                               'value': np.array(add_w) if add_w is not None else None},
+                             'add_1_w': {'shape': np.array(add_w.shape) if add_w is not None else None,
+                                         'value': np.array(add_w) if add_w is not None else None},
+                             }, nodes_with_edges_only=True)
+        del graph['add_1']['add_1_data'][0]['in']
+        return graph
+
+    @staticmethod
+    def _create_graph_with_mul(mul_w: np.ndarray):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_1_data', 'mul_1'),
+                             ('const_mul_1_w', 'mul_1_w'),
+                             ('mul_1_w', 'mul_1'),
+                             ('mul_1', 'mul_1_data'),
+                             ('mul_1_data', 'op_output')
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
+                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
+                             'const_mul_1_w': {'shape': np.array(mul_w.shape) if mul_w is not None else None,
+                                               'value': np.array(mul_w) if mul_w is not None else None},
+                             'mul_1_w': {'shape': np.array(mul_w.shape) if mul_w is not None else None,
+                                         'value': np.array(mul_w) if mul_w is not None else None},
+                             }, nodes_with_edges_only=True)
+        del graph['mul_1']['mul_1_data'][0]['in']
+        return graph
+
+    def test_add_to_scaleshift_1(self):
+        graph = AddToScaleShift._create_graph_with_add(np.array([1, 2, 3], dtype=np.float32))
+        graph.stage = 'middle'
+
+        graph_ref = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_1_data', 'scaleshift_1'),
+                             ('const_scaleshift_1_w', 'scaleshift_1_w'),
+                             ('const_scaleshift_1_b', 'scaleshift_1_b'),
+                             ('scaleshift_1_w', 'scaleshift_1'),
+                             ('scaleshift_1_b', 'scaleshift_1'),
+                             ('scaleshift_1', 'scaleshift_1_data'),
+                             ('scaleshift_1_data', 'op_output')
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
+                             'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])},
+
+                             'const_scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
+                             'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
+
+                             'const_scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+                             'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+                             }, nodes_with_edges_only=True)
+
+        convert_add_or_mul_to_scaleshift(graph)
+        graph_clean_up(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'op_output')
+        self.assertTrue(flag, resp)
+
+        scsh_node = Node(graph, 'op_output').in_port(0).get_source().node
+
+        self.assertTrue(graph.get_edge_data(scsh_node.in_node(1).id, scsh_node.id)[0]['bin'] == 'weights')
+        self.assertTrue(graph.get_edge_data(scsh_node.in_node(2).id, scsh_node.id)[0]['bin'] == 'biases')
+
+    def test_mul_to_scaleshift_1(self):
+        graph = AddToScaleShift._create_graph_with_mul(np.array([1, 2, 3], dtype=np.float32))
+        graph.stage = 'middle'
+
+        graph_ref = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_1_data', 'scaleshift_1'),
+                             ('const_scaleshift_1_w', 'scaleshift_1_w'),
+                             ('const_scaleshift_1_b', 'scaleshift_1_b'),
+                             ('scaleshift_1_w', 'scaleshift_1'),
+                             ('scaleshift_1_b', 'scaleshift_1'),
+                             ('scaleshift_1', 'scaleshift_1_data'),
+                             ('scaleshift_1_data', 'op_output')
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
+                             'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])},
+
+                             'const_scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+                             'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
+
+                             'const_scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
+                             'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
+                             }, nodes_with_edges_only=True)
+
+        convert_add_or_mul_to_scaleshift(graph)
+        graph_clean_up(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'op_output')
+        self.assertTrue(flag, resp)
+
+        scsh_node = Node(graph, 'op_output').in_port(0).get_source().node
+
+        self.assertTrue(graph.get_edge_data(scsh_node.in_node(1).id, scsh_node.id)[0]['bin'] == 'weights')
+        self.assertTrue(graph.get_edge_data(scsh_node.in_node(2).id, scsh_node.id)[0]['bin'] == 'biases')
+
+
+
 if __name__ == '__main__':
     unittest.main()