Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / flatten_test.py
index 9d58401..75de344 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -24,7 +24,8 @@ from mo.utils.unittest.graph import build_graph
 
 nodes_attributes = {'node_1': {'value': None, 'kind': 'data'},
                     'flatten_1': {'type': 'Flatten', 'value': None, 'kind': 'op'},
-                    'node_2': {'value': None, 'kind': 'data'}
+                    'node_2': {'value': None, 'kind': 'data'},
+                    'output_op': { 'kind': 'op', 'op': 'OpOutput'},
                     }
 
 
@@ -32,8 +33,10 @@ class TestFlattenPartialInfer(unittest.TestCase):
     def test_flatten_infer(self):
         graph = build_graph(nodes_attributes,
                             [('node_1', 'flatten_1'),
-                             ('flatten_1', 'node_2')],
-                            {'node_2': {'is_output': True, 'shape': np.array([1, 3 * 256 * 256])},
+                             ('flatten_1', 'node_2'),
+                             ('node_2', 'op_output')
+                             ],
+                            {'node_2': {'shape': np.array([1, 3 * 256 * 256])},
                              'node_1': {'shape': np.array([1, 3, 256, 256])},
                              'flatten_1': {'axis': 1, 'dim': []}
                              })
@@ -49,8 +52,10 @@ class TestFlattenPartialInfer(unittest.TestCase):
     def test_flatten_infer_no_shape(self):
         graph = build_graph(nodes_attributes,
                             [('node_1', 'flatten_1'),
-                             ('flatten_1', 'node_2')],
-                            {'node_2': {'is_output': True, 'shape': None},
+                             ('flatten_1', 'node_2'),
+                             ('node_2', 'op_output')
+                             ],
+                            {'node_2': {'shape': None},
                              'node_1': {'shape': None},
                              'flatten_1': {'axis': 1}
                              })